diff --git a/docs/transformers/src/transformers/models/myt5/tokenization_myt5.py b/docs/transformers/src/transformers/models/myt5/tokenization_myt5.py new file mode 100644 index 0000000000000000000000000000000000000000..f86e5e44c2a0e7381318acec2418c81625e7690c --- /dev/null +++ b/docs/transformers/src/transformers/models/myt5/tokenization_myt5.py @@ -0,0 +1,380 @@ +# coding=utf-8 +# Copyright 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model MyT5.""" + +import json +import os +import warnings +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"} + + +class ByteRewriter: + """ + Byte rewriter class for MyT5 tokenizer. + This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules. + + Args: + rewriting_rules (`str` or `Dict[str, str]`): + A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules. + + """ + + LEAF = "[LEAF]" + + def __init__(self, rewriting_rules: Union[str, Dict[str, str]]): + if isinstance(rewriting_rules, str): + with open(rewriting_rules, "r") as f: + rewriting_rules = json.load(f) + elif not isinstance(rewriting_rules, dict): + raise ValueError( + f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}" + ) + + self.hash_tree = self.construct_hash_tree(rewriting_rules) + reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()} + self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules) + + def add_leaf(self, hash_tree: Dict[str, Union[dict, List[str]]], byte_in_sequence: str, byte_out_sequence: str): + """ + Add a leaf with the output byte sequence to the hash tree. + """ + byte_in_list = byte_in_sequence.split(" ") + byte_out_list = byte_out_sequence.split(" ") + + tree_pointer = hash_tree + for b in byte_in_list: + if b not in tree_pointer: + tree_pointer[b] = {} + tree_pointer = tree_pointer[b] + + tree_pointer[self.LEAF] = byte_out_list + + def construct_hash_tree(self, rewriting_rules: Dict[str, str]) -> Dict[str, Union[dict, List[str]]]: + """ + Construct a hash tree for rewritten byte sequences. + """ + hash_tree = defaultdict(dict) + for b in (f"{x:02x}" for x in range(256)): + hash_tree[b][self.LEAF] = [b] + + for in_sequence, out_sequence in rewriting_rules.items(): + self.add_leaf(hash_tree, in_sequence, out_sequence) + + return hash_tree + + def search_hash_tree(self, byte_sequence: List[str]) -> Union[None, List[str]]: + """ + Search the hash tree and return the rewritten byte sequence if found. 
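As a toy illustration of the rewriting behaviour implemented by `rewrite_bytes` below (the rule set here is hypothetical, not the shipped decompose/merge byte maps): a rule mapping the byte pair `61 62` to the single byte `80` is applied greedily, unmatched bytes pass through unchanged, and `reverse=True` applies the inverted rules to undo the rewrite.

```python
# Hypothetical rule set: rewrite the byte pair "61 62" ("ab") to the byte "80".
rules = {"61 62": "80"}
rewriter = ByteRewriter(rules)

print(rewriter.rewrite_bytes(["61", "62", "63"]))          # ['80', '63']
print(rewriter.rewrite_bytes(["80", "63"], reverse=True))  # ['61', '62', '63']
```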
+        """
+        tree_pointer = self.hash_tree
+        for b in byte_sequence:
+            if b in tree_pointer:
+                tree_pointer = tree_pointer[b]
+            else:
+                return None
+
+        return tree_pointer[self.LEAF]
+
+    def rewrite_bytes(self, in_bytes: List[str], reverse=False) -> List[str]:
+        """
+        Rewrite a sequence of bytes using the hash tree.
+
+        Args:
+            in_bytes (`List[str]`): A list of bytes to be rewritten.
+            reverse (`bool`): If True, decoding is performed with the reverse hash tree.
+        Returns:
+            `List[str]`: The rewritten byte sequence.
+        """
+        out_bytes = []
+        b_start = 0
+        b_end = 0
+
+        while b_start < len(in_bytes):
+            tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree
+            for j in range(b_start, len(in_bytes)):
+                b = in_bytes[j]
+                if b in tree_pointer:
+                    tree_pointer = tree_pointer[b]
+                elif j == b_start:
+                    cur_leaf = [b]
+                    b_end = j
+                    break
+                else:
+                    break
+                # remember the longest complete match seen so far
+                if self.LEAF in tree_pointer:
+                    cur_leaf = tree_pointer[self.LEAF]
+                    b_end = j
+            out_bytes.extend(cur_leaf)
+            b_start = b_end + 1
+
+        return out_bytes
+
+
+class MyT5Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a MyT5 tokenizer.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`): The file containing the byte rewriting rules.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        extra_ids (`int`, *optional*, defaults to 125):
+            Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
+            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
+            indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
+            like in ByT5 preprocessing see
+            [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+    """
+
+    model_input_names = ["input_ids", "attention_mask"]
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        extra_ids=125,
+        additional_special_tokens=None,
+        **kwargs,
+    ) -> None:
+        # Add extra_ids to the special token list
+        if extra_ids > 0 and additional_special_tokens is None:
+            additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+        elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
+            # Check that we have the right number of extra_id special tokens
+            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
+            if extra_tokens != extra_ids:
+                raise ValueError(
+                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+                    " provided to MyT5Tokenizer.
In this case the additional_special_tokens must include the" + " extra_ids tokens" + ) + + pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token + eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token + # unk token needs to be in the vocab with correct index + self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token} + self.offset = len(self._added_tokens_decoder) + self._utf_vocab_size = 2**8 # utf is 8 bits + + # Load byte maps + self.byte_maps = json.load(open(vocab_file, "r")) + + self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"]) + self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"]) + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + extra_ids=0, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self): + return self._utf_vocab_size + + # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. 
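A minimal sketch of the resulting id scheme, assuming a local `byte_maps.json` (hypothetical path) holding the decompose/merge maps: ids 0-2 are reserved for the pad/eos/unk tokens, so each two-hex-character byte token maps to its integer value plus the offset of 3.

```python
tok = MyT5Tokenizer("byte_maps.json")  # hypothetical local path to the byte maps

assert tok.offset == 3                             # ids 0, 1, 2 hold <pad>, </s>, <unk>
assert tok._convert_token_to_id("41") == 0x41 + 3  # byte 0x41 ("A") -> id 68
assert tok._convert_id_to_token(68) == "41"
```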
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+        if token_ids_1 is None:
+            return token_ids_0
+        else:
+            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+            return token_ids_0 + token_ids_1
+
+    def _tokenize(self, text: str, **kwargs) -> List[str]:
+        """Take as input a string and return a list of strings (tokens) for words/sub-words.
+        Represents tokens in two character hex format"""
+
+        tokens = [f"{i:02x}" for i in text.encode("utf-8")]
+        tokens = self.morphological_encode(tokens)
+        return tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+
+        if len(token) != 2:
+            token_id = None
+        else:
+            token_id = int(token, 16) + self.offset
+
+        return token_id
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        token = f"{index - self.offset:02x}"
+        return token
+
+    def morphological_encode(self, indices: List[str]) -> List[str]:
+        # Decompose and merge morphological sequences
+        indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
+        indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
+        return indices
+
+    def morphological_decode(self, indices: List[str]) -> List[str]:
+        # Demerge and compose morphological sequences
+        indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
+        indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
+        return indices
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        bstring = b""
+
+        out_tokens = []
+        for token in tokens:
+            if token in self.added_tokens_decoder:
+                out_tokens.append(self.added_tokens_decoder[token])
+            elif token in self.added_tokens_encoder:
+                out_tokens.append(token)
+            else:
+                out_tokens.append(token)
+
+        out_tokens = self.morphological_decode(out_tokens)
+        _added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder)
+        for token in out_tokens:
+            if token in _added_tokens:
+                bstring += bytes(token, "utf-8")
+            else:
+                bstring += bytes.fromhex(token)
+        string = bstring.decode("utf-8", errors="ignore")
+        return string
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+        else:
+            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+        with open(vocab_file, "w", encoding="utf-8") as writer:
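+            # write the decompose/merge byte maps back out as JSON so the tokenizer can be reloaded from this file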
writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False)) + return (vocab_file,) + + +__all__ = ["MyT5Tokenizer"] diff --git a/docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py b/docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f1ad0c5d9c4c3ceeb87a4593c741e0e1d8cfbf --- /dev/null +++ b/docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Nemotron model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class NemotronConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NemotronModel`]. It is used to instantiate an Nemotron + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Nemotron-8B. + e.g. [nvidia/nemotron-3-8b-base-4k-hf](https://huggingface.co/nvidia/nemotron-3-8b-base-4k-hf). + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Nemotron model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`NemotronModel`] + hidden_size (`int`, *optional*, defaults to 6144): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the Transformer decoder. + head_dim (`int`, *optional*): + Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if None + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. 
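The mean-pooling conversion mentioned for `num_key_value_heads` above can be sketched as follows (toy shapes, not tied to any particular checkpoint):

```python
import torch

# Hypothetical MHA -> GQA conversion: 8 query heads collapsed into 2 KV groups,
# mean-pooling the original key heads inside each group.
num_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 16, 64
k_proj = torch.randn(num_heads * head_dim, hidden_size)  # original key projection

grouped = k_proj.view(num_kv_heads, num_heads // num_kv_heads, head_dim, hidden_size)
k_proj_gqa = grouped.mean(dim=1).reshape(num_kv_heads * head_dim, hidden_size)
print(k_proj_gqa.shape)  # torch.Size([32, 64])
```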
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.0134): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 3): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj and down_proj layers in the MLP layers. + + ```python + >>> from transformers import NemotronModel, NemotronConfig + + >>> # Initializing a Nemotron nemotron-15b style configuration + >>> configuration = NemotronConfig() + + >>> # Initializing a model from the nemotron-15b style configuration + >>> model = NemotronModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nemotron" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=6144, + intermediate_size=24576, + num_hidden_layers=32, + num_attention_heads=48, + head_dim=None, + num_key_value_heads=None, + hidden_act="relu2", + max_position_embeddings=4096, + initializer_range=0.0134, + norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=2, + eos_token_id=3, + tie_word_embeddings=False, + rope_theta=10000.0, + partial_rotary_factor=0.5, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + rope_config_validation(self) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + 
tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["NemotronConfig"] diff --git a/docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py b/docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..07a5c51a3c4bcd1062e39c51223f2a46ff8544f0 --- /dev/null +++ b/docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py @@ -0,0 +1,346 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils import logging +from pytorch_lightning import Trainer + +from transformers import LlamaTokenizer, PreTrainedTokenizerFast +from transformers.convert_slow_tokenizer import LlamaConverter + + +""" +Script to convert a nemotron checkpoint in nemo (mcore path) into a HuggingFace checkpoint. +This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder. + +1) Generate only HF weights from a nemo file: + + python convert_nemotron_nemo_to_hf.py \ + --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ + --output_path /path/to/pytorch_model.bin + +2) Generate the full HF model folder + + python convert_nemotron_nemo_to_hf.py \ + --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ + --hf_input_path /path/to/input_hf_folder \ + --hf_output_path /path/to/output_hf_folder \ + + Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Nemotron4 340b). + However this option makes the conversion script significantly slower. +""" + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to .nemo file or extracted folder", + ) + parser.add_argument("--output_path", type=str, default=None, required=False, help="Path to HF .bin file") + parser.add_argument( + "--hf_input_path", + type=str, + default=None, + help="A HF model path, e.g. a folder containing https://huggingface.co/nvidia/Minitron-8B-Base", + ) + parser.add_argument( + "--hf_output_path", + type=str, + default=None, + help="Output HF model path, with the same format as above but user's own weights", + ) + parser.add_argument( + "--precision", + type=str, + default=None, + help="Precision of output weights." + "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)", + ) + parser.add_argument( + "--cpu-only", + action="store_true", + help="Load model in cpu only. 
Useful if the model cannot fit in GPU memory, " + "but this option makes the conversion script significantly slower.", + ) + args = parser.parse_args() + return args + + +def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, hf_url="nvidia/Minitron-8B-Base"): + """ + Convert NeMo config to HF config + """ + NEMO_ACT2HF = { + "squared-relu": "relu2", + "fast-swiglu": "silu", + } + DTYPE2HF = { + torch.bfloat16: "bfloat16", + torch.float16: "float16", + torch.float32: "float32", + } + hf_config = { + "_name_or_path": hf_url, + "architectures": ["NemotronForCausalLM"], + "bos_token_id": tokenizer.bos_id, + "eos_token_id": tokenizer.eos_id, + "hidden_act": NEMO_ACT2HF[nemo_config.activation], + "hidden_size": nemo_config.hidden_size, + "initializer_range": nemo_config.init_method_std, + "intermediate_size": nemo_config.ffn_hidden_size, + "max_position_embeddings": nemo_config.max_position_embeddings, + "model_type": "nemotron", + "num_attention_heads": nemo_config.num_attention_heads, + "num_hidden_layers": nemo_config.num_layers, + "num_key_value_heads": nemo_config.get("num_query_groups", nemo_config.num_attention_heads), + "norm_eps": nemo_config.layernorm_epsilon, + "rope_theta": nemo_config.get("rotary_base", 10000), + "partial_rotary_factor": nemo_config.get("rotary_percentage", 1.0), + "tie_word_embeddings": False, + "torch_dtype": DTYPE2HF[dtype], + "transformers_version": "4.32.0.dev0", # TODO + "use_cache": True, + "vocab_size": vocab_size, + } + if nemo_config.kv_channels is not None: + hf_config["kv_channels"] = nemo_config.kv_channels + json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2) + + +def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None: + """ + Convert NeMo weights to HF weights + """ + dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy()) + model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) + model_config.tensor_model_parallel_size = 1 + model_config.pipeline_model_parallel_size = 1 + model_config.sequence_parallel = False + model_config.transformer_engine = True + if cpu_only: + map_location = torch.device("cpu") + model_config.use_cpu_initialization = True + model_config.dist_ckpt_load_on_device = False + else: + map_location = None + + if cpu_only: + logging.info("******** Loading model on CPU. 
This will take a significant amount of time.") + + model = MegatronGPTModel.restore_from( + input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location + ) + + vocab_size = model.padded_vocab_size + + if precision is None: + precision = model.cfg.precision + if precision in [32, "32"]: + dtype = torch.float32 + elif precision in [16, "16", "16-mixed"]: + dtype = torch.float16 + elif precision in ["bf16", "bf16-mixed"]: + dtype = torch.bfloat16 + else: + logging.warning(f"Precision string {precision} is not recognized, falling back to fp32") + dtype = torch.float32 # fallback + logging.info(f"Using precision {dtype}") + + def param_to_weights(param): + return param.to(dtype) + + checkpoint = OrderedDict() + + hidden_size = model.cfg.hidden_size + head_num = model.cfg.num_attention_heads + num_layers = model.cfg.num_layers + ffn_hidden_size = model.cfg.ffn_hidden_size + num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B + if num_query_groups is None: + num_query_groups = head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + # Embedding + embed_weight = model.state_dict()["model.embedding.word_embeddings.weight"] + embed_weights_base_name = "model.embed_tokens.weight" + checkpoint[embed_weights_base_name] = param_to_weights(embed_weight) + + for l in range(int(num_layers)): + print(f"converting layer {l}") + + qkv_weights = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.weight"] + qkv_weights = qkv_weights.reshape([qkv_total_dim, -1, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + ## Example of slices + ## (without GQA): num_query_groups = head_num = 32, + ## q_slice = [0, 3, 6, 9 , ... 90, 93] + ## k_slice = [1, 4, 7, 10, ... 91, 94] + ## v_slice = [2, 5, 8, 11, ... 92, 95] + ## (with GQA): num_query_groups = 8, head_num = 64 + ## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77] + ## k_slice = [8, 18, 28, ... 68, 78] + ## v_slice = [9, 19, 29, ... 69, 79] + + q_weights_base_name = f"model.layers.{l}.self_attn.q_proj.weight" + k_weights_base_name = f"model.layers.{l}.self_attn.k_proj.weight" + v_weights_base_name = f"model.layers.{l}.self_attn.v_proj.weight" + + checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size)) + checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size)) + checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size)) + + # attention dense + o_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_proj.weight"] + o_weight_base_name = f"model.layers.{l}.self_attn.o_proj.weight" + checkpoint[o_weight_base_name] = param_to_weights(o_weight) + + # mlp + mlp_weights = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.weight"] + mlp_up_proj_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc2.weight"] + + if mlp_weights.shape[0] != mlp_up_proj_weight.shape[1]: + # Has projection (used for swi-glu) + logging.warning( + "Gated projection layers detected in NeMo checkpoint. Currently Nemotron HF does not support gated MLP." 
+ ) + assert mlp_weights.shape[0] == 2 * mlp_up_proj_weight.shape[1] + + mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :] + mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :] + + mlp_down_proj_base_name = f"model.layers.{l}.mlp.gate_proj.weight" + mlp_gate_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight" + + checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight) + checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight) + else: + mlp_down_proj_weight = mlp_weights + mlp_down_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight" + checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight) + + mlp_up_proj_base_name = f"model.layers.{l}.mlp.down_proj.weight" + checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight) + + # layernorm + input_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight"] + input_ln_base_name = f"model.layers.{l}.input_layernorm.weight" + checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight) + if ( + model.state_dict().get(f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias", None) + is not None + ): + input_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias"] + input_ln_bias_name = f"model.layers.{l}.input_layernorm.bias" + checkpoint[input_ln_bias_name] = param_to_weights(input_ln_bias) + + post_attn_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight"] + post_attn_ln_base_name = f"model.layers.{l}.post_attention_layernorm.weight" + checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight) + if model.state_dict().get(f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias", None) is not None: + post_attn_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias"] + post_attn_ln_bias_name = f"model.layers.{l}.post_attention_layernorm.bias" + checkpoint[post_attn_ln_bias_name] = param_to_weights(post_attn_ln_bias) + + print(f"done layer {l}") + + final_ln_weight = model.state_dict()["model.decoder.final_layernorm.weight"] + final_ln_base_name = "model.norm.weight" + checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight) + if model.state_dict().get("model.decoder.final_layernorm.bias", None) is not None: + final_ln_bias = model.state_dict()["model.decoder.final_layernorm.bias"] + final_ln_bias_name = "model.norm.bias" + checkpoint[final_ln_bias_name] = param_to_weights(final_ln_bias) + + output_layer_weight = model.state_dict()["model.output_layer.weight"] + output_layer_base_name = "lm_head.weight" + checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight) + + os.makedirs(os.path.dirname(output_hf_file), exist_ok=True) + torch.save(checkpoint, output_hf_file) + logging.info(f"Weights saved to {output_hf_file}") + + return model_config, model.tokenizer, dtype, vocab_size + + +def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tokenizer): + tokenizer_cfg = model_config.tokenizer + if tokenizer_cfg.library == "sentencepiece": + # For sentencepiece tokenizer, we are wrapping with HF's LlamaTokenizer + # and convert it to a PreTrainedTokenizerFast + tokenizer_fn = tokenizer_cfg.model[5:] + output_tokenizer = f"{output_hf_path}/tokenizer.model" + if nemo_file.endswith(".nemo"): + import tarfile + + archive = tarfile.open(nemo_file, "r") + tokenizer_filename = "./" + tokenizer_fn # exclude 'nemo:' prefix + 
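+                # the sentencepiece model ships inside the .nemo tar archive; extract it into the HF output folder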
archive.extract(tokenizer_filename, output_hf_path) + archive.close() + os.rename(f"{output_hf_path}/{tokenizer_fn}", output_tokenizer) + elif os.path.isdir(nemo_file): + shutil.copy(f"{nemo_file}/{tokenizer_fn}", output_tokenizer) + # We use LlamaTokenizer for sentencepiece based tokenizer + tokenizer = LlamaTokenizer.from_pretrained(output_hf_path, legacy=False) + # Convert the LlamaTokenizer to a PreTrainedTokenizerFast instance + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=LlamaConverter(tokenizer).converted(), model_input_names=["input_ids", "token_type_ids"] + ) + tokenizer.save_pretrained(output_hf_path) + logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}") + elif isinstance(nemo_tokenizer, AutoTokenizer): + nemo_tokenizer.tokenizer.save_pretrained(output_hf_path) + logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}") + else: + raise ValueError(f"Unsupported tokenizer type: library: {tokenizer_cfg.library}, type: {tokenizer_cfg.type}") + + +if __name__ == "__main__": + args = get_args() + if not args.hf_output_path: + assert args.output_path is not None, "Need to provide either output_path or hf_output_path" + else: + args.output_path = f"{args.hf_output_path}/pytorch_model.bin" + logging.info(f"weight will be saved to {args.output_path}") + + nemo_config, nemo_tokenizer, dtype, vocab_size = convert( + args.input_name_or_path, args.output_path, precision=args.precision, cpu_only=args.cpu_only + ) + if args.hf_input_path and args.hf_output_path: + convert_hf_config(nemo_config, nemo_tokenizer, vocab_size, dtype, args.hf_output_path, args.hf_input_path) + extract_nemotron_tokenizer(args.input_name_or_path, nemo_config, args.hf_output_path, nemo_tokenizer) + else: + logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.") + logging.info(f".bin file is saved to {args.output_path}") diff --git a/docs/transformers/src/transformers/models/nllb/tokenization_nllb.py b/docs/transformers/src/transformers/models/nllb/tokenization_nllb.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac0059ddd21760811a71b12fd4d3fa0a0a25388 --- /dev/null +++ b/docs/transformers/src/transformers/models/nllb/tokenization_nllb.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn'] # fmt: skip + + +@requires(backends=("sentencepiece",)) +class NllbTokenizer(PreTrainedTokenizer): + """ + Construct an NLLB tokenizer. + + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + The tokenization method is ` ` for source language documents, and ` + ` for target language documents. + + Examples: + + ```python + >>> from transformers import NllbTokenizer + + >>> tokenizer = NllbTokenizer.from_pretrained( + ... "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn" + ... 
)
+    >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+    >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
+    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
+    ```
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        tokenizer_file (`str`, *optional*):
+            The path to a tokenizer file to use instead of the vocab file.
+        src_lang (`str`, *optional*):
+            The language to use as source language for translation.
+        tgt_lang (`str`, *optional*):
+            The language to use as target language for translation.
+        sp_model_kwargs (`Dict[str, str]`):
+            Additional keyword arguments to pass to the model initialization.
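A small sketch of the default (non-legacy) special-token layout, assuming the checkpoint above is reachable: the source language code is prepended and `eos` is appended.

```python
from transformers import NllbTokenizer

tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
ids = tok("UN Chief says there is no military solution in Syria").input_ids

assert ids[0] == tok.convert_tokens_to_ids("eng_Latn")  # language-code prefix
assert ids[-1] == tok.eos_token_id                      # eos suffix
```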
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        tokenizer_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        additional_special_tokens=None,
+        legacy_behaviour=False,
+        **kwargs,
+    ):
+        if additional_special_tokens is None:
+            additional_special_tokens = FAIRSEQ_LANGUAGE_CODES
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+        # Mask token behaves like a normal word, i.e. includes the space before it
+        mask_token = (
+            AddedToken(mask_token, normalized=True, lstrip=True, special=True)
+            if isinstance(mask_token, str)
+            else mask_token
+        )
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.legacy_behaviour = legacy_behaviour
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4   |  5   |  6   |   7  |   8  |  9
+        # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a'
+        # spm      | '<unk>' | '<s>'   | '</s>' | 'an'    | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' | '▁s'
+
+        # unk token needs to be in the vocab with correct index
+        self._added_tokens_decoder = {0: bos_token, 1: pad_token, 2: eos_token, 3: unk_token}
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+        self.sp_model_size = len(self.sp_model)
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            legacy_behaviour=legacy_behaviour,
+            **kwargs,
+        )
+
+        self._src_lang = src_lang if src_lang is not None else "eng_Latn"
+        self.cur_lang_code_id = self.convert_tokens_to_ids(self._src_lang)
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+    @property
+    def vocab_size(self):
+        return len(self.sp_model) + self.fairseq_offset
+
+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        self._src_lang = new_src_lang
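+        # changing the source language immediately rebuilds the prefix/suffix special tokens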
self.set_src_lang_special_tokens(self._src_lang) + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. 
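Continuing the sketch above (same `tok`, default non-legacy mode), the special-tokens mask exposes the one-token language-code prefix and the eos suffix:

```python
# ids without special tokens; 1 marks the lang-code prefix and the eos suffix
mask = tok.get_special_tokens_mask([100, 200, 300])
print(mask)  # [1, 0, 0, 0, 1]
```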
+ + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + spm_id = self.sp_model.PieceToId(token) + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "eng_Latn", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "fra_Latn", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + - In legacy mode: No prefix and suffix=[eos, src_lang_code]. 
+ - In default mode: Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + if self.legacy_behaviour: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + - In legacy mode: No prefix and suffix=[eos, tgt_lang_code]. + - In default mode: Prefix=[tgt_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + if self.legacy_behaviour: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + +__all__ = ["NllbTokenizer"] diff --git a/docs/transformers/src/transformers/models/nllb_moe/__init__.py b/docs/transformers/src/transformers/models/nllb_moe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e21bde18010d3c81455b0a66eed8263c47a0e5de --- /dev/null +++ b/docs/transformers/src/transformers/models/nllb_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_nllb_moe import * + from .modeling_nllb_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py b/docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..f161930d90d0435ce4e8603eb12c9903d4e7832f --- /dev/null +++ b/docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2023, HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NLLB-MoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class NllbMoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NllbMoeModel`]. 
It is used to instantiate an
+    NLLB-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the NLLB-MoE
+    [facebook/nllb-moe-54b](https://huggingface.co/facebook/nllb-moe-54b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 128112):
+            Vocabulary size of the NllbMoe model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`NllbMoeModel`] or [`NllbMoeForConditionalGeneration`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        second_expert_policy (`str`, *optional*, defaults to `"all"`):
+            The policy used to sample the probability of each token being routed to a second expert.
+        normalize_router_prob_before_dropping (`bool`, *optional*, defaults to `False`):
+            Whether or not to normalize the router probabilities before applying a mask based on the experts capacity
+            (capacity dropping).
+        batch_prioritized_routing (`bool`, *optional*, defaults to `False`):
+            Whether or not to order the tokens by their router probabilities before capacity dropping. This means that
+            the tokens that have the highest probabilities will be routed before other tokens that might be further in
+            the sequence.
+        moe_eval_capacity_token_fraction (`float`, *optional*, defaults to 1.0):
+            Fraction of the total number of tokens used as expert capacity during validation; if set to a negative
+            value, the same capacity as during training is used. Should be in range: (0.0, 1.0].
+        num_experts (`int`, *optional*, defaults to 128):
+            Number of experts for each NllbMoeSparseMlp layer.
+        expert_capacity (`int`, *optional*, defaults to 64):
+            Number of tokens that can be stored in each expert.
+        encoder_sparse_step (`int`, *optional*, defaults to 4):
+            Frequency of the sparse layers in the encoder. 4 means that one out of 4 layers will be sparse.
+        decoder_sparse_step (`int`, *optional*, defaults to 4):
+            Frequency of the sparse layers in the decoder. 4 means that one out of 4 layers will be sparse.
+        router_dtype (`str`, *optional*, defaults to `"float32"`):
+            The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
+            *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
+        router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
+            Whether to ignore padding tokens when routing. If `False`, the padding tokens are masked out of the
+            top-1 and top-2 assignments and are not routed to any expert.
+        router_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not the classifier of the router should have a bias.
+        router_z_loss_coef (`float`, *optional*, defaults to 0.001):
+            The coefficient for the router z-loss computed on the raw router logits.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The coefficient for the auxiliary load balancing loss.
+        moe_token_dropout (`float`, *optional*, defaults to 0.2):
+            Masking rate for MoE expert output masking (EOM): during training, expert outputs are dropped out at this
+            rate, and at inference time they are scaled by `1 - moe_token_dropout`.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not to return the router logits. Only set to `True` to get the auxiliary loss when training.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
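+
+    As a rough illustration of how the capacity-related defaults interact (the helper below is purely
+    illustrative and not part of the library): when `expert_capacity` is set to `None`, the router falls back
+    to `2 * ceil(num_tokens / num_experts)` tokens per expert at training time.
+
+    ```python
+    import math
+
+    def fallback_expert_capacity(num_tokens: int, num_experts: int) -> int:
+        # mirrors the training-time fallback used when `expert_capacity` is None
+        return 2 * math.ceil(num_tokens / num_experts)
+
+    fallback_expert_capacity(1024, 128)  # 16 tokens per expert
+    ```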
+ + Example: + + ```python + >>> from transformers import NllbMoeModel, NllbMoeConfig + + >>> # Initializing a NllbMoe facebook/nllb-moe-54b style configuration + >>> configuration = NllbMoeConfig() + + >>> # Initializing a model from the facebook/nllb-moe-54b style configuration + >>> model = NllbMoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nllb-moe" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=128112, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=1024, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + scale_embedding=True, + router_bias=False, + router_dtype="float32", + router_ignore_padding_tokens=False, + num_experts=128, + expert_capacity=64, + encoder_sparse_step=4, + decoder_sparse_step=4, + router_z_loss_coef=0.001, + router_aux_loss_coef=0.001, + second_expert_policy="all", + normalize_router_prob_before_dropping=False, + batch_prioritized_routing=False, + moe_eval_capacity_token_fraction=1.0, + moe_token_dropout=0.2, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + output_router_logits=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.router_z_loss_coef = router_z_loss_coef + self.router_aux_loss_coef = router_aux_loss_coef + self.decoder_sparse_step = decoder_sparse_step + self.encoder_sparse_step = encoder_sparse_step + self.num_experts = num_experts + self.expert_capacity = expert_capacity + self.router_bias = router_bias + if router_dtype not in ["float32", "float16", "bfloat16"]: + raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}") + self.router_dtype = router_dtype + + self.router_ignore_padding_tokens = router_ignore_padding_tokens + self.batch_prioritized_routing = batch_prioritized_routing + self.second_expert_policy = second_expert_policy + self.normalize_router_prob_before_dropping = normalize_router_prob_before_dropping + self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction + self.moe_token_dropout = moe_token_dropout + self.output_router_logits = output_router_logits + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + 
decoder_start_token_id=decoder_start_token_id,
+            **kwargs,
+        )
+
+
+__all__ = ["NllbMoeConfig"]
diff --git a/docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..317c5c713c7f50c4641bf2b1efb233c2d521a160
--- /dev/null
+++ b/docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
@@ -0,0 +1,161 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+
+import torch
+from torch import nn
+
+from transformers import NllbMoeConfig, NllbMoeModel
+from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+
+
+def remove_ignore_keys_(state_dict):
+    ignore_keys = [
+        "encoder.version",
+        "decoder.version",
+        "model.encoder.version",
+        "model.decoder.version",
+        "decoder.output_projection.weight",
+        "_float_tensor",
+        "encoder.embed_positions._float_tensor",
+        "decoder.embed_positions._float_tensor",
+    ]
+    for k in ignore_keys:
+        state_dict.pop(k, None)
+
+
+def make_linear_from_emb(emb):
+    vocab_size, emb_size = emb.weight.shape
+    # in_features=emb_size and out_features=vocab_size so that the linear weight has shape
+    # (vocab_size, emb_size), matching the embedding weight that is copied below
+    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
+    lin_layer.weight.data = emb.weight.data
+    return lin_layer
+
+
+def rename_fairseq_keys(state_dict, expert_idx=None):
+    new_dict = {}
+    for old_key in state_dict.keys():
+        key = old_key
+        if "moe_layer.experts." in key:
+            if expert_idx is not None:
+                key = key.replace("moe_layer.experts.0", f"ffn.experts.expert_{expert_idx}")
+            else:
+                key = key.replace("moe_layer.experts.", "ffn.experts.expert_")
+        if "gate" in key:
+            key = key.replace(".moe_layer.gate.wg", ".ffn.router.classifier")
+        if "fc2" in key and "experts" not in key:
+            key = key.replace(".fc2.", ".ffn.fc2.")
+        if "fc1" in key and "experts" not in key:
+            key = key.replace(".fc1.", ".ffn.fc1.")
+        if ".encoder_attn." 
in key: + key = key.replace(".encoder_attn.", ".cross_attention.") + if "encoder_attn_layer_norm" in key: + key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm") + if "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ff_layer_norm") + new_dict[key] = state_dict[old_key] + return new_dict + + +def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME): + sharded_state_dicts = [] + total_size = 0 + os.makedirs(dump_path, exist_ok=True) + + for expert in range(num_experts): + expert_path = switch_checkpoint_path + f"-rank-{expert}.pt" + if os.path.isfile(expert_path): + expert_state = torch.load(expert_path, weights_only=True)["model"] + remove_ignore_keys_(expert_state) + expert_state = rename_fairseq_keys(expert_state, expert) + save_path = os.path.join( + dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin") + ) + torch.save(expert_state, save_path) + sharded_state_dicts.append(expert_state.keys()) + total_size += sum([value.numel() for key, value in expert_state.items()]) * ( + expert_state[list(expert_state)[0]].element_size() + ) + + # Add the last block + save_path = os.path.join( + dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin") + ) + shared_weights = torch.load(switch_checkpoint_path + "-shared.pt", weights_only=True)["model"] + remove_ignore_keys_(shared_weights) + shared_weights = rename_fairseq_keys(shared_weights, None) + shared_weights["shared.weight"] = shared_weights["decoder.embed_tokens.weight"] + sharded_state_dicts.append(shared_weights.keys()) + + # If we only have the shared weights (dummy model/experts saved on the same file) + if len(sharded_state_dicts) == 1: + save_path = os.path.join(dump_path, weights_name) + torch.save(shared_weights, save_path) + return {weights_name: sharded_state_dicts[0]}, None + else: + torch.save(shared_weights, save_path) + # Otherwise, let's build the index + weight_map = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin") + temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx + 1:05d}-of-???.bin")) + os.rename(temp_filename, os.path.join(dump_path, shard_file)) + for key in shard: + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + + with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + return metadata, index + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--nllb_moe_checkpoint_path", + default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/model_moe_54b/checkpoint_2_300000", + type=str, + required=False, + help="Path to a directory containing a folder per layer. 
Follows the original Google format.", + ) + parser.add_argument("--dtype", default="float32", type=str, required=False, help="dtype of the saved model") + parser.add_argument( + "--pytorch_dump_folder_path", + default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/hf-converted-moe-54b", + type=str, + required=False, + help="Path to the output pytorch model.", + ) + args = parser.parse_args() + metadata, index = shard_on_the_fly( + args.nllb_moe_checkpoint_path, + args.pytorch_dump_folder_path, + 128, + args.dtype, + ) + + config = NllbMoeConfig.from_pretrained( + "facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128 + ) + config.save_pretrained(args.pytorch_dump_folder_path) + model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path) + print("Done") + model.save_pretrained(args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2e735f9fe26bd9531dfe264ffd1a0108a933aa --- /dev/null +++ b/docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -0,0 +1,1784 @@ +# coding=utf-8 +# Copyright 2023 NllbMoe Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch NLLB-MoE model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + MoEModelOutput, + MoEModelOutputWithPastAndCrossAttentions, + Seq2SeqMoEModelOutput, + Seq2SeqMoEOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_nllb_moe import NllbMoeConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "NllbMoeConfig" +_CHECKPOINT_FOR_DOC = "hf-internal-testing/dummy-nllb-moe-2-experts" +_REAL_CHECKPOINT_FOR_DOC = "facebook/nllb-moe-54b" + + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. 
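+
+    For example (illustrative only), with `pad_token_id=1` and `decoder_start_token_id=2`:
+
+    ```python
+    >>> import torch
+    >>> labels = torch.tensor([[5, 6, 7, -100]])
+    >>> shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
+    tensor([[2, 5, 6, 7]])
+    ```
+
+    Any `-100` that remains after the shift is replaced by `pad_token_id`.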
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + router_probs (`torch.Tensor`): + Probability assigned to each expert per token. Shape: [batch_size, seqeunce_length, num_experts]. + expert_indices (`torch.Tensor`): + Indices tensor of shape [batch_size, seqeunce_length] identifying the selected expert for a given token. + + Returns: + The auxiliary loss. + """ + if router_probs is None: + return 0 + + num_experts = router_probs.shape[-1] + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if expert_indices.dtype != torch.int64: + expert_indices = expert_indices.to(torch.int64) + + if len(expert_indices.shape) == 2: + expert_indices = expert_indices.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe +class NllbMoeScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding +class NllbMoeSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
+
+
+class NllbMoeTop2Router(nn.Module):
+    """
+    Router using "tokens choose top-2 experts" assignment.
+
+    This router uses the same mechanism as in NLLB-MoE from the fairseq repository. Items are sorted by router_probs
+    and then routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee
+    that each token is processed by an expert**, or that each expert receives at least one token.
+
+    The router combining weights are also returned to make sure that the states that are not updated will be masked.
+
+    """
+
+    def __init__(self, config: NllbMoeConfig):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.expert_capacity = config.expert_capacity
+        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
+        self.router_ignore_padding_tokens = config.router_ignore_padding_tokens
+        self.dtype = getattr(torch, config.router_dtype)
+
+        self.second_expert_policy = config.second_expert_policy
+        self.normalize_router_prob_before_dropping = config.normalize_router_prob_before_dropping
+        self.batch_prioritized_routing = config.batch_prioritized_routing
+        self.moe_eval_capacity_token_fraction = config.moe_eval_capacity_token_fraction
+
+    def _cast_classifier(self):
+        r"""
+        `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore, we need to check whether the
+        classifier is an instance of the `Linear8bitLt` class by checking for its special attributes.
+        """
+        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
+            self.classifier = self.classifier.to(self.dtype)
+
+    def normalize_router_probabilities(self, router_probs, top_1_mask, top_2_mask):
+        top_1_max_probs = (router_probs * top_1_mask).sum(dim=1)
+        top_2_max_probs = (router_probs * top_2_mask).sum(dim=1)
+        denom_s = torch.clamp(top_1_max_probs + top_2_max_probs, min=torch.finfo(router_probs.dtype).eps)
+        top_1_max_probs = top_1_max_probs / denom_s
+        top_2_max_probs = top_2_max_probs / denom_s
+        return top_1_max_probs, top_2_max_probs
+
+    def route_tokens(
+        self,
+        router_logits: torch.Tensor,
+        input_dtype: torch.dtype = torch.float32,
+        padding_mask: Optional[torch.LongTensor] = None,
+    ) -> Tuple:
+        """
+        Computes the `dispatch_mask` and the `dispatch_weights` for each expert. The masks are adapted to the expert
+        capacity.
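+
+        For intuition: with unnormalized top-1/top-2 probabilities of 0.5 and 0.3 for a token,
+        `normalize_router_probabilities` rescales them to 0.625 and 0.375 so that the two combining
+        weights sum to 1.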
+ """ + nb_tokens = router_logits.shape[0] + # Apply Softmax and cast back to the original `dtype` + router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(input_dtype) + top_1_expert_index = torch.argmax(router_probs, dim=-1) + top_1_mask = torch.nn.functional.one_hot(top_1_expert_index, num_classes=self.num_experts) + + if self.second_expert_policy == "sampling": + gumbel = torch.distributions.gumbel.Gumbel(0, 1).rsample + router_logits += gumbel(router_logits.shape).to(router_logits.device) + + # replace top_1_expert_index with min values + logits_except_top_1 = router_logits.masked_fill(top_1_mask.bool(), float("-inf")) + top_2_expert_index = torch.argmax(logits_except_top_1, dim=-1) + top_2_mask = torch.nn.functional.one_hot(top_2_expert_index, num_classes=self.num_experts) + + if self.normalize_router_prob_before_dropping: + top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities( + router_probs, top_1_mask, top_2_mask + ) + + if self.second_expert_policy == "random": + top_2_max_probs = (router_probs * top_2_mask).sum(dim=1) + sampled = (2 * top_2_max_probs) > torch.rand_like(top_2_max_probs.float()) + top_2_mask = top_2_mask * sampled.repeat(self.num_experts, 1).transpose(1, 0) + + if padding_mask is not None and not self.router_ignore_padding_tokens: + if len(padding_mask.shape) == 4: + # only get the last causal mask + padding_mask = padding_mask[:, :, -1, :].reshape(-1)[-nb_tokens:] + non_padding = ~padding_mask.bool() + top_1_mask = top_1_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype) + top_2_mask = top_2_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype) + + if self.batch_prioritized_routing: + # sort tokens based on their routing probability + # to make sure important tokens are routed, first + importance_scores = -1 * router_probs.max(dim=1)[0] + sorted_top_1_mask = top_1_mask[importance_scores.argsort(dim=0)] + sorted_cumsum1 = (torch.cumsum(sorted_top_1_mask, dim=0) - 1) * sorted_top_1_mask + locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)] + + sorted_top_2_mask = top_2_mask[importance_scores.argsort(dim=0)] + sorted_cumsum2 = (torch.cumsum(sorted_top_2_mask, dim=0) - 1) * sorted_top_2_mask + locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)] + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(top_1_mask, dim=0, keepdim=True) + + else: + locations1 = torch.cumsum(top_1_mask, dim=0) - 1 + locations2 = torch.cumsum(top_2_mask, dim=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(top_1_mask, dim=0, keepdim=True) + + if not self.training and self.moe_eval_capacity_token_fraction > 0: + self.expert_capacity = math.ceil(self.moe_eval_capacity_token_fraction * nb_tokens) + else: + capacity = 2 * math.ceil(nb_tokens / self.num_experts) + self.expert_capacity = capacity if self.expert_capacity is None else self.expert_capacity + + # Remove locations outside capacity from ( cumsum < capacity = False will not be routed) + top_1_mask = top_1_mask * torch.lt(locations1, self.expert_capacity) + top_2_mask = top_2_mask * torch.lt(locations2, self.expert_capacity) + + if not self.normalize_router_prob_before_dropping: + top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities( + router_probs, top_1_mask, top_2_mask + ) + + # Calculate combine_weights and dispatch_mask + gates1 = top_1_max_probs[:, None] * top_1_mask + gates2 = top_2_max_probs[:, None] * top_2_mask + router_probs = gates1 + 
+
+        return top_1_mask, router_probs
+
+    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.LongTensor] = None) -> Tuple:
+        r"""
+        The hidden states are reshaped to simplify the computation of the router probabilities (combining weights for
+        each expert).
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
+        Returns:
+            top_1_mask (`torch.Tensor` of shape (batch_size * sequence_length, num_experts)):
+                One-hot mask marking, for each token, the expert selected by the top-1 routing probabilities
+                (after capacity dropping).
+            router_probabilities (`torch.Tensor` of shape (batch_size * sequence_length, num_experts)):
+                Tensor corresponding to the combining probabilities for each token and expert. Used for routing
+                tokens to experts.
+        """
+        self.input_dtype = hidden_states.dtype
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim)
+        hidden_states = hidden_states.to(self.dtype)
+        self._cast_classifier()
+        router_logits = self.classifier(hidden_states)
+        top_1_mask, router_probs = self.route_tokens(router_logits, self.input_dtype, padding_mask)
+        return top_1_mask, router_probs
+
+
+class NllbMoeDenseActDense(nn.Module):
+    def __init__(self, config: NllbMoeConfig, ffn_dim: int):
+        super().__init__()
+        self.fc1 = nn.Linear(config.d_model, ffn_dim)
+        self.fc2 = nn.Linear(ffn_dim, config.d_model)
+        self.dropout = nn.Dropout(config.activation_dropout)
+        self.act = ACT2FN[config.activation_function]
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        if (
+            isinstance(self.fc2.weight, torch.Tensor)
+            and hidden_states.dtype != self.fc2.weight.dtype
+            and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8)
+        ):
+            hidden_states = hidden_states.to(self.fc2.weight.dtype)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class NllbMoeSparseMLP(nn.Module):
+    r"""
+    Implementation of the NLLB-MoE sparse MLP module.
+    """
+
+    def __init__(self, config: NllbMoeConfig, ffn_dim: int, expert_class: nn.Module = NllbMoeDenseActDense):
+        super().__init__()
+        self.router = NllbMoeTop2Router(config)
+        self.moe_token_dropout = config.moe_token_dropout
+        self.token_dropout = nn.Dropout(self.moe_token_dropout)
+        self.num_experts = config.num_experts
+
+        self.experts = nn.ModuleDict()
+        for idx in range(self.num_experts):
+            self.experts[f"expert_{idx}"] = expert_class(config, ffn_dim)
+
+    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = False):
+        r"""
+        The goal of this forward pass is to have the same number of operations as the equivalent
+        `NllbMoeDenseActDense` (mlp) layer. This means that all of the hidden states should be processed at most
+        twice (since we are using a top-2 gating mechanism). This means that we keep the complexity to
+        O(batch_size x sequence_length x hidden_dim) instead of O(num_experts x batch_size x sequence_length x
+        hidden_dim).
+
+        1- Get the `router_probs` from the `router`.
The shape of the `router_mask` is `(batch_size X sequence_length, + num_expert)` and corresponds to the boolean version of the `router_probs`. The inputs are masked using the + `router_mask`. + + 2- Dispatch the hidden_states to its associated experts. The router probabilities are used to weight the + contribution of each experts when updating the masked hidden states. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`): + The hidden states + padding_mask (`torch.Tensor`, *optional*, defaults to `False`): + Attention mask. Can be in the causal form or not. + + Returns: + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`): + Updated hidden states + router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`): + Needed for computing the loss + + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + + top_1_mask, router_probs = self.router(hidden_states, padding_mask) + router_mask = router_probs.bool() + hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim) + masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask) + for idx, expert in enumerate(self.experts.values()): + token_indices = router_mask[:, idx] + combining_weights = router_probs[token_indices, idx] + expert_output = expert(masked_hidden_states[idx, token_indices]) + if self.moe_token_dropout > 0: + if self.training: + expert_output = self.token_dropout(expert_output) + else: + expert_output *= 1 - self.moe_token_dropout + masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output) + hidden_states = masked_hidden_states.sum(dim=0).reshape(batch_size, sequence_length, hidden_dim) + + top_1_expert_index = torch.argmax(top_1_mask, dim=-1) + return hidden_states, (router_probs, top_1_expert_index) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->NllbMoe,key_value_states->encoder_hidden_states +class NllbMoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[NllbMoeConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `encoder_hidden_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
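+        # At this point `attn_output` has shape (bsz, tgt_len, num_heads, head_dim); the reshape below
+        # merges the two head dimensions back into `embed_dim`.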
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class NllbMoeEncoderLayer(nn.Module): + def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): + super().__init__() + self.embed_dim = config.d_model + self.is_sparse = is_sparse + self.self_attn = NllbMoeAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + if not self.is_sparse: + self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.encoder_ffn_dim) + else: + self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.encoder_ffn_dim) + self.ff_layer_norm = nn.LayerNorm(config.d_model) + self.ff_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + output_router_logits: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ff_layer_norm(hidden_states) + if self.is_sparse: + hidden_states, router_states = self.ffn(hidden_states, attention_mask) + else: + # router_states set to None to track which layers have None gradients. 
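+            # (NllbMoeDenseActDense only returns hidden states, so the tuple is padded with `None` to keep
+            # the same unpacking as the sparse branch)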
+ hidden_states, router_states = self.ffn(hidden_states), None + + hidden_states = self.ff_dropout(hidden_states) + + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if output_router_logits: + outputs += (router_states,) + + return outputs + + +class NllbMoeDecoderLayer(nn.Module): + def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): + super().__init__() + self.embed_dim = config.d_model + self.is_sparse = is_sparse + self.self_attn = NllbMoeAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = NllbMoeAttention( + self.embed_dim, config.decoder_attention_heads, config.attention_dropout, is_decoder=True + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + if not self.is_sparse: + self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.decoder_ffn_dim) + else: + self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.decoder_ffn_dim) + self.ff_layer_norm = nn.LayerNorm(config.d_model) + self.ff_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + layer_head_mask (`torch.FloatTensor`): + mask for attention heads in a given layer of size `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): + mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
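+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of the sparse router. They are useful for computing the router
+                loss, and should not be returned during inference.
+            use_cache (`bool`, *optional*, defaults to `True`):
+                Whether or not the cached key/value states should be returned, to be reused in subsequent decoding
+                steps.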
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=cross_attn_past_key_value, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value += cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + + hidden_states = self.ff_layer_norm(hidden_states) + if self.is_sparse: + hidden_states, router_states = self.ffn(hidden_states, attention_mask) + else: + hidden_states, router_states = self.ffn(hidden_states), None + + hidden_states = self.ff_dropout(hidden_states) + + hidden_states = residual + hidden_states + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, present_key_value) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if output_router_logits: + outputs += (router_states,) + + return outputs + + +class NllbMoePreTrainedModel(PreTrainedModel): + config_class = NllbMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +NLLB_MOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`NllbMoeConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+NLLB_MOE_GENERATION_EXAMPLE = r"""
+    Translation example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, NllbMoeForConditionalGeneration
+
+    >>> model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")
+
+    >>> text_to_translate = "Life is like a box of chocolates"
+    >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt")
+
+    >>> # translate to French
+    >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"))
+    >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
+    ```
+"""
+
+NLLB_MOE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            NllbMoe uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder.
Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class NllbMoeEncoder(NllbMoePreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`NllbMoeEncoderLayer`]. 
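+    Depending on `config.encoder_sparse_step`, every `sparse_step`-th layer is built with a sparse
+    [`NllbMoeSparseMLP`] instead of a dense feed-forward block: with the default step of 4, layers
+    4, 8 and 12 of a 12-layer encoder are sparse.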
+ + Args: + config: + NllbMoeConfig + embed_tokens (nn.Embedding): + output embedding + """ + + def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = NllbMoeScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = NllbMoeSinusoidalPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + self.padding_idx, + ) + sparse_step = config.encoder_sparse_step + self.layers = nn.ModuleList() + for i in range(config.encoder_layers): + is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False + self.layers.append(NllbMoeEncoderLayer(config, is_sparse)) + + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input_ids, inputs_embeds) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_router_probs = () if output_router_logits else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_router_logits: + all_router_probs += (layer_outputs[-1],) + + last_hidden_state = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states += (last_hidden_state,) + + if not return_dict: + return tuple( + v for v in [last_hidden_state, encoder_states, all_attentions, all_router_probs] if v is not None + ) + + return MoEModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=encoder_states, + attentions=all_attentions, + router_probs=all_router_probs, + ) + + +class NllbMoeDecoder(NllbMoePreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`NllbMoeDecoderLayer`] + + Args: + config: + NllbMoeConfig + embed_tokens (nn.Embedding): + output embedding + """ + + def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = NllbMoeScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = NllbMoeSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + ) + + sparse_step = config.decoder_sparse_step + self.layers = nn.ModuleList() + for i in range(config.decoder_layers): + is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False + self.layers.append(NllbMoeDecoderLayer(config, is_sparse)) + + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
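
As an illustration of the `past_key_values` contract above, here is a minimal sketch of incremental decoding with this module. The import path of the internal `NllbMoeDecoder` class and the tiny config values are assumptions made for the example, not a documented API:

```python
>>> import torch

>>> from transformers import NllbMoeConfig
>>> from transformers.models.nllb_moe.modeling_nllb_moe import NllbMoeDecoder

>>> # tiny config; `decoder_sparse_step=0` keeps every layer dense so no experts are built
>>> config = NllbMoeConfig(vocab_size=128, d_model=64, decoder_layers=2, decoder_ffn_dim=128, decoder_sparse_step=0)
>>> decoder = NllbMoeDecoder(config).eval()
>>> encoder_hidden_states = torch.randn(1, 6, 64)  # stand-in for a real encoder output

>>> # the prefix pass fills the cache with self- and cross-attention key/value states
>>> out = decoder(
...     input_ids=torch.tensor([[2, 17, 25, 31]]),
...     encoder_hidden_states=encoder_hidden_states,
...     use_cache=True,
...     return_dict=True,
... )

>>> # each later step feeds only the newest token together with the cache
>>> step = decoder(
...     input_ids=torch.tensor([[42]]),
...     encoder_hidden_states=encoder_hidden_states,
...     past_key_values=out.past_key_values,
...     use_cache=True,
...     return_dict=True,
... )
>>> step.last_hidden_state.shape
torch.Size([1, 1, 64])
```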
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_probs = () if output_router_logits else None + all_cross_attentions = () if output_attentions else None + present_key_value_states = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False + if not skip_the_layer or synced_gpus: + layer_head_mask = head_mask[idx] if head_mask is not None else None + cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.forward, + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) + + hidden_states = layer_outputs[0] + + if skip_the_layer: + continue + + if use_cache: + present_key_value_states += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + all_cross_attentions += (layer_outputs[3],) + + if output_router_logits: + all_router_probs += (layer_outputs[-1],) + + hidden_states = self.layer_norm(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_self_attns, + all_cross_attentions, + all_router_probs, + ] + if v is not None + ) + return MoEModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + router_probs=all_router_probs, + ) + + +@add_start_docstrings( + "The bare NllbMoe Model outputting raw hidden-states without any specific head on top.", + NLLB_MOE_START_DOCSTRING, +) +class NllbMoeModel(NllbMoePreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: NllbMoeConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = NllbMoeEncoder(config, self.shared) + self.decoder = NllbMoeDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def 
set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(NLLB_MOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqMoEModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, NllbMoeModel + + >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts") + >>> model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ...
).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for NllbMoeModel + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, MoEModelOutput): + encoder_outputs = MoEModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqMoEModelOutput( + past_key_values=decoder_outputs.past_key_values, + cross_attentions=decoder_outputs.cross_attentions, + last_hidden_state=decoder_outputs.last_hidden_state, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + decoder_hidden_states=decoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + decoder_attentions=decoder_outputs.attentions, + encoder_router_logits=encoder_outputs.router_probs, + decoder_router_logits=decoder_outputs.router_probs, + ) + + +@add_start_docstrings( + "The NllbMoe Model with a language modeling head. 
Can be used for summarization.", NLLB_MOE_START_DOCSTRING +) +class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: NllbMoeConfig): + super().__init__(config) + self.model = NllbMoeModel(config) + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.router_z_loss_coef = config.router_z_loss_coef + self.router_aux_loss_coef = config.router_aux_loss_coef + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(NLLB_MOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(NLLB_MOE_GENERATION_EXAMPLE) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqMoEOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
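
When only `labels` is provided, `decoder_input_ids` are created automatically by shifting the labels one position to the right, as in the forward pass below. A sketch of building labels, assuming a NLLB-MoE tokenizer has already been loaded as `tokenizer`:

```python
>>> labels = tokenizer(text_target="Life is like a box of chocolates.", return_tensors="pt").input_ids
```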
+ + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + loss = None + encoder_aux_loss = None + decoder_aux_loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # todo check in the config if router loss enables + + if output_router_logits: + encoder_router_logits = outputs[-1] + decoder_router_logits = outputs[3 if output_attentions else 4] + + # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder + encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits) + encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes) + + decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_router_logits) + decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes) + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if output_router_logits and labels is not None: + aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss) + loss = loss + aux_loss + + output = (loss,) if loss is not None else () + if not return_dict: + output += (lm_logits,) + if output_router_logits: # only return the loss if they are not None + output += ( + encoder_aux_loss, + decoder_aux_loss, + *outputs[1:], + ) + else: + output += outputs[1:] + + return output + + return Seq2SeqMoEOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + cross_attentions=outputs.cross_attentions, + encoder_aux_loss=encoder_aux_loss, + decoder_aux_loss=decoder_aux_loss, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + decoder_hidden_states=outputs.decoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + decoder_attentions=outputs.decoder_attentions, + encoder_router_logits=outputs.encoder_router_logits, + decoder_router_logits=outputs.decoder_router_logits, + ) + + def _unpack_router_logits(self, router_outputs): + total_router_logits = [] + total_expert_indexes = [] + for router_output in router_outputs: + if router_output is not None: + router_logits, expert_indexes = router_output + total_router_logits.append(router_logits) + total_expert_indexes.append(expert_indexes) + + total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None + total_expert_indexes = 
torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None + return total_router_logits, total_expert_indexes + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = [ + "NllbMoeForConditionalGeneration", + "NllbMoeModel", + "NllbMoePreTrainedModel", + "NllbMoeTop2Router", + "NllbMoeSparseMLP", +] diff --git a/docs/transformers/src/transformers/models/nougat/__init__.py b/docs/transformers/src/transformers/models/nougat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c87d75e58e8fbb2a39f692029c766a912638d5c --- /dev/null +++ b/docs/transformers/src/transformers/models/nougat/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .image_processing_nougat import * + from .processing_nougat import * + from .tokenization_nougat_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py b/docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..e42f8553ac4f5abf11ca942c108a3e613825bfd0 --- /dev/null +++ b/docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Nougat checkpoints using the original `nougat` library. 
URL: +https://github.com/facebookresearch/nougat/tree/main""" + +import argparse + +import torch +from huggingface_hub import hf_hub_download +from nougat import NougatModel +from nougat.dataset.rasterize import rasterize_paper +from nougat.utils.checkpoint import get_checkpoint +from PIL import Image + +from transformers import ( + DonutSwinConfig, + DonutSwinModel, + MBartConfig, + MBartForCausalLM, + NougatImageProcessor, + NougatProcessor, + NougatTokenizerFast, + VisionEncoderDecoderModel, +) + + +def get_configs(model): + original_config = model.config + + encoder_config = DonutSwinConfig( + image_size=original_config.input_size, + patch_size=4, + depths=original_config.encoder_layer, + num_heads=[4, 8, 16, 32], + window_size=original_config.window_size, + embed_dim=128, + ) + decoder_config = MBartConfig( + is_decoder=True, + is_encoder_decoder=False, + add_cross_attention=True, + decoder_layers=original_config.decoder_layer, + max_position_embeddings=original_config.max_position_embeddings, + vocab_size=len( + model.decoder.tokenizer + ), # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json) + scale_embedding=True, + add_final_layer_norm=True, + tie_word_embeddings=False, + ) + + return encoder_config, decoder_config + + +# Copied from transformers.models.donut.convert_donut_to_pytorch.rename_key +def rename_key(name): + if "encoder.model" in name: + name = name.replace("encoder.model", "encoder") + if "decoder.model" in name: + name = name.replace("decoder.model", "decoder") + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "embeddings.norm") + if name.startswith("encoder"): + if "layers" in name: + name = "encoder." 
+ + name + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name and "mask" not in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + + if name == "encoder.norm.weight": + name = "encoder.layernorm.weight" + if name == "encoder.norm.bias": + name = "encoder.layernorm.bias" + + return name + + +# Copied from transformers.models.donut.convert_donut_to_pytorch.convert_state_dict +def convert_state_dict(orig_state_dict, model): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[3]) + block_num = int(key_split[5]) + dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size + + if "weight" in key: + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight" + ] = val[:dim, :] + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = ( + val[dim : dim * 2, :] + ) + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight" + ] = val[-dim:, :] + else: + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = ( + val[:dim] + ) + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = ( + val[dim : dim * 2] + ) + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = ( + val[-dim:] + ) + elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]: + # HuggingFace implementation doesn't use attn_mask buffer + # and model doesn't use final LayerNorms for the encoder + pass + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +def convert_nougat_checkpoint(model_tag, pytorch_dump_folder_path=None, push_to_hub=False): + # load original model + checkpoint_path = get_checkpoint(None, model_tag) + original_model = NougatModel.from_pretrained(checkpoint_path) + original_model.eval() + + # load HuggingFace model + encoder_config, decoder_config = get_configs(original_model) + encoder = DonutSwinModel(encoder_config) + decoder = MBartForCausalLM(decoder_config) + model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder) + model.eval() + + state_dict = original_model.state_dict() + new_state_dict = convert_state_dict(state_dict, model) + model.load_state_dict(new_state_dict) + + # verify results on PDF + filepath = hf_hub_download(repo_id="ysharma/nougat", filename="input/nougat.pdf", repo_type="space") + images = rasterize_paper(pdf=filepath, return_pil=True) + image = Image.open(images[0]) + + tokenizer_file = checkpoint_path / "tokenizer.json" + tokenizer = NougatTokenizerFast(tokenizer_file=str(tokenizer_file)) + tokenizer.pad_token = "<pad>" + tokenizer.bos_token = "<s>" + tokenizer.eos_token = "</s>" + tokenizer.unk_token = "<unk>" + tokenizer.model_max_length = original_model.config.max_length + + size = {"height": original_model.config.input_size[0], "width": original_model.config.input_size[1]} + image_processor = NougatImageProcessor( + do_align_long_axis=original_model.config.align_long_axis, +
size=size, + ) + processor = NougatProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # verify pixel_values + pixel_values = processor(image, return_tensors="pt").pixel_values + original_pixel_values = original_model.encoder.prepare_input(image).unsqueeze(0) + + assert torch.allclose(original_pixel_values, pixel_values) + + # verify patch embeddings + original_patch_embed = original_model.encoder.model.patch_embed(pixel_values) + patch_embeddings, _ = model.encoder.embeddings(pixel_values) + assert torch.allclose(original_patch_embed, patch_embeddings) + + # verify encoder hidden states + original_last_hidden_state = original_model.encoder(pixel_values) + last_hidden_state = model.encoder(pixel_values).last_hidden_state + assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2) + + # NOTE original model does not use tied weights for embeddings of decoder + original_embeddings = original_model.decoder.model.model.decoder.embed_tokens + embeddings = model.decoder.model.decoder.embed_tokens + assert torch.allclose(original_embeddings.weight, embeddings.weight, atol=1e-3) + + # verify decoder hidden states + prompt = "hello world" + decoder_input_ids = original_model.decoder.tokenizer( + prompt, add_special_tokens=False, return_tensors="pt" + ).input_ids + decoder_attention_mask = torch.ones_like(decoder_input_ids) + original_logits = original_model( + image_tensors=pixel_values, decoder_input_ids=decoder_input_ids, attention_mask=decoder_attention_mask + ).logits + logits = model( + pixel_values, + decoder_input_ids=decoder_input_ids[:, :-1], + decoder_attention_mask=decoder_attention_mask[:, :-1], + ).logits + assert torch.allclose(original_logits, logits, atol=1e-3) + + # verify generation + outputs = model.generate( + pixel_values, + min_length=1, + max_length=30, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + use_cache=True, + bad_words_ids=[ + [tokenizer.unk_token_id], + ], + return_dict_in_generate=True, + do_sample=False, + ) + generated = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0] + + if model_tag == "0.1.0-base": + expected_generation = "# Nougat: Neural Optical Understanding for Academic Documents\n\nLukas Blecher\n\nCorrespondence to: lblec" + elif model_tag == "0.1.0-small": + expected_generation = ( + "# Nougat: Neural Optical Understanding for Academic Documents\n\nLukas Blecher\n\nCorrespondence to: lble" + ) + else: + raise ValueError(f"Unexpected model tag: {model_tag}") + + assert generated == expected_generation + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + tag_to_name = {"0.1.0-base": "nougat-base", "0.1.0-small": "nougat-small"} + model_name = tag_to_name[model_tag] + + model.push_to_hub(f"facebook/{model_name}") + processor.push_to_hub(f"facebook/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_tag", + default="0.1.0-base", + required=False, + type=str, + choices=["0.1.0-base", "0.1.0-small"], + help="Tag of the original model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + required=False, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + 
action="store_true", + help="Whether or not to push the converted model and processor to the 🤗 hub.", + ) + + args = parser.parse_args() + convert_nougat_checkpoint(args.model_tag, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/nougat/image_processing_nougat.py b/docs/transformers/src/transformers/models/nougat/image_processing_nougat.py new file mode 100644 index 0000000000000000000000000000000000000000..25b5c5e7bc858ad2437418c9925a398bf6b17dea --- /dev/null +++ b/docs/transformers/src/transformers/models/nougat/image_processing_nougat.py @@ -0,0 +1,525 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Nougat.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + pad, + resize, + to_channel_dimension_format, + to_pil_image, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.import_utils import is_cv2_available, is_vision_available + + +logger = logging.get_logger(__name__) + + +if is_cv2_available(): + pass + + +if is_vision_available(): + import PIL + + +class NougatImageProcessor(BaseImageProcessor): + r""" + Constructs a Nougat image processor. + + Args: + do_crop_margin (`bool`, *optional*, defaults to `True`): + Whether to crop the image margins. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 896, "width": 672}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `False`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the images to the largest image size in the batch. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. 
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Image standard deviation. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_crop_margin: bool = True, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_thumbnail: bool = True, + do_align_long_axis: bool = False, + do_pad: bool = True, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + size = size if size is not None else {"height": 896, "width": 672} + size = get_size_dict(size) + + self.do_crop_margin = do_crop_margin + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_thumbnail = do_thumbnail + self.do_align_long_axis = do_align_long_axis + self.do_pad = do_pad + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + def python_find_non_zero(self, image: np.array): + """This is a reimplementation of a findNonZero function equivalent to cv2.""" + non_zero_indices = np.column_stack(np.nonzero(image)) + idxvec = non_zero_indices[:, [1, 0]] + idxvec = idxvec.reshape(-1, 1, 2) + return idxvec + + def python_bounding_rect(self, coordinates): + """This is a reimplementation of a BoundingRect function equivalent to cv2.""" + min_values = np.min(coordinates, axis=(0, 1)).astype(int) + max_values = np.max(coordinates, axis=(0, 1)).astype(int) + x_min, y_min = min_values[0], min_values[1] + width = max_values[0] - x_min + 1 + height = max_values[1] - y_min + 1 + return x_min, y_min, width, height + + def crop_margin( + self, + image: np.array, + gray_threshold: int = 200, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.array: + """ + Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the + threshold). + + Args: + image (`np.array`): + The image to be cropped. + gray_threshold (`int`, *optional*, defaults to `200`) + Value below which pixels are considered to be gray. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the output image. If unset, will use the inferred format from the + input. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. 
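
For intuition, a small sketch of the cropping behaviour on a synthetic page (the array below is an assumption made for this example, not a fixture from the library):

```python
>>> import numpy as np

>>> from transformers import NougatImageProcessor

>>> processor = NougatImageProcessor()
>>> page = np.full((100, 80, 3), 255, dtype=np.uint8)  # white page; the uniform area counts as margin
>>> page[45:55, 35:45] = 0  # a dark 10x10 "text" block in the middle
>>> processor.crop_margin(page).shape  # only the non-margin region survives
(10, 10, 3)
```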
+ """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + image = to_pil_image(image, input_data_format=input_data_format) + data = np.array(image.convert("L")).astype(np.uint8) + max_val = data.max() + min_val = data.min() + if max_val == min_val: + image = np.array(image) + image = ( + to_channel_dimension_format(image, data_format, input_data_format) + if data_format is not None + else image + ) + return image + data = (data - min_val) / (max_val - min_val) * 255 + gray = data < gray_threshold + coords = self.python_find_non_zero(gray) + x_min, y_min, width, height = self.python_bounding_rect(coords) + image = image.crop((x_min, y_min, x_min + width, y_min + height)) + image = np.array(image).astype(np.uint8) + image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST) + + image = ( + to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image + ) + + return image + + # Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.align_long_axis + def align_long_axis( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`np.ndarray`): + The image to be aligned. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + + Returns: + `np.ndarray`: The aligned image. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = size["height"], size["width"] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(image) + + if input_data_format == ChannelDimension.LAST: + rot_axes = (0, 1) + elif input_data_format == ChannelDimension.FIRST: + rot_axes = (1, 2) + else: + raise ValueError(f"Unsupported data format: {input_data_format}") + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = np.rot90(image, 3, axes=rot_axes) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + def pad_image( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pad the image to the specified size at the top, bottom, left and right. + + Args: + image (`np.ndarray`): + The image to be padded. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + output_height, output_width = size["height"], size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + delta_width = output_width - input_width + delta_height = output_height - input_height + + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = ((pad_top, pad_bottom), (pad_left, pad_right)) + return pad(image, padding, data_format=data_format, input_data_format=input_data_format) + + # Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.thumbnail + def thumbnail( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. + + Args: + image (`np.ndarray`): + The image to be resized. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = size["height"], size["width"] + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + return resize( + image, + size=(height, width), + resample=resample, + reducing_gap=2.0, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resizes `image` to `(height, width)` specified by `size` using the PIL library. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
+ """ + size = get_size_dict(size) + shortest_edge = min(size["height"], size["width"]) + output_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format + ) + resized_image = resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return resized_image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_crop_margin: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: Optional[bool] = None, + do_align_long_axis: Optional[bool] = None, + do_pad: Optional[bool] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Union[int, float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. + do_crop_margin (`bool`, *optional*, defaults to `self.do_crop_margin`): + Whether to crop the image margins. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to min(size["height"], + size["width"]) with the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the images to the largest image size in the batch. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_crop_margin = do_crop_margin if do_crop_margin is not None else self.do_crop_margin + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail + do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis + do_pad = do_pad if do_pad is not None else self.do_pad + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
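# `infer_channel_dimension_format` guesses the layout from the array shape: an axis of size 1 or 3 in the
# first or last position is taken to be the channel axis. Inferring it once from `images[0]` is why a batch
# that mixes channels-first and channels-last inputs is not supported here.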
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_crop_margin: + images = [self.crop_margin(image, input_data_format=input_data_format) for image in images] + + if do_align_long_axis: + images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images] + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_thumbnail: + images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images] + + if do_pad: + images = [self.pad_image(image=image, size=size, input_data_format=input_data_format) for image in images] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["NougatImageProcessor"] diff --git a/docs/transformers/src/transformers/models/nougat/processing_nougat.py b/docs/transformers/src/transformers/models/nougat/processing_nougat.py new file mode 100644 index 0000000000000000000000000000000000000000..ca395e261affe7ba5dbfc7ac22320e299cb78170 --- /dev/null +++ b/docs/transformers/src/transformers/models/nougat/processing_nougat.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Nougat. +""" + +from typing import Dict, List, Optional, Union + +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy + +from ...processing_utils import ProcessorMixin +from ...utils import PaddingStrategy, TensorType + + +class NougatProcessor(ProcessorMixin): + r""" + Constructs a Nougat processor which wraps a Nougat image processor and a Nougat tokenizer into a single processor. + + [`NougatProcessor`] offers all the functionalities of [`NougatImageProcessor`] and [`NougatTokenizerFast`]. See the + [`~NougatProcessor.__call__`] and [`~NougatProcessor.decode`] for more information. + + Args: + image_processor ([`NougatImageProcessor`]): + An instance of [`NougatImageProcessor`]. The image processor is a required input. + tokenizer ([`NougatTokenizerFast`]): + An instance of [`NougatTokenizerFast`]. The tokenizer is a required input. 
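+
+    Example (a minimal usage sketch; it assumes the `facebook/nougat-base` checkpoint is available and uses a
+    hypothetical local image file `page.png`):
+
+    ```python
+    >>> from PIL import Image
+
+    >>> from transformers import NougatProcessor
+
+    >>> processor = NougatProcessor.from_pretrained("facebook/nougat-base")
+    >>> image = Image.open("page.png")  # hypothetical scanned page
+    >>> inputs = processor(images=image, return_tensors="pt")  # BatchFeature holding "pixel_values"
+    ```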
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + + def __call__( + self, + images=None, + text=None, + do_crop_margin: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: "PILImageResampling" = None, # noqa: F821 + do_thumbnail: Optional[bool] = None, + do_align_long_axis: Optional[bool] = None, + do_pad: Optional[bool] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Union[int, float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 + input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821 + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ): + if images is None and text is None: + raise ValueError("You need to specify either an `images` or `text` input to process.") + + if images is not None: + inputs = self.image_processor( + images, + do_crop_margin=do_crop_margin, + do_resize=do_resize, + size=size, + resample=resample, + do_thumbnail=do_thumbnail, + do_align_long_axis=do_align_long_axis, + do_pad=do_pad, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + input_data_format=input_data_format, + ) + if text is not None: + encodings = self.tokenizer( + text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + ) + + if text is None: + return inputs + elif images is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to NougatTokenizer's 
[`~PreTrainedTokenizer.batch_decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_generation(self, *args, **kwargs): + """ + This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.post_process_generation`]. + Please refer to the docstring of this method for more information. + """ + return self.tokenizer.post_process_generation(*args, **kwargs) + + +__all__ = ["NougatProcessor"] diff --git a/docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py b/docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..e5dc6ed1645369dd15c9c2324a82fca576e82501 --- /dev/null +++ b/docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py @@ -0,0 +1,620 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fast tokenizer class for Nougat. +""" + +import re +from functools import partial +from multiprocessing import Pool +from typing import List, Optional, Union + +import numpy as np + +from transformers.tokenization_utils_base import INIT_TOKENIZER_DOCSTRING +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import add_end_docstrings + +from ...utils import is_levenshtein_available, is_nltk_available, logging, requires_backends + + +if is_levenshtein_available(): + from Levenshtein import ratio + +if is_nltk_available(): + import nltk + + +logger = logging.get_logger(__name__) + + +INIT_TOKENIZER_DOCSTRING += """ + tokenizer_object ([`tokenizers.Tokenizer`]): + A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗 + tokenizers](../fast_tokenizers) for more information. + tokenizer_file ([`str`]): + A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗 + tokenizers. +""" + + +VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"} + + +def markdown_compatible(text: str) -> str: + """ + Make text compatible with Markdown formatting. + + This function makes various text formatting adjustments to make it compatible with Markdown. + + Args: + text (`str`): + The input text to be made Markdown-compatible. + + Returns: + `str`: The Markdown-compatible text. + """ + # equation tag + # Replace lines that start with a pattern like (decimal) \[some text\] with \[[some text] \tag{decimal}\]. + text = re.sub(r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]", text, flags=re.M) + # Replace lines that start with a pattern like \[some text\] (decimal) with \[[some text] \tag{decimal}\]. 
+ text = re.sub(r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]", text, flags=re.M) + # Replace lines that start with a pattern like \[some text\] (digits) \[another text\] with \[[some text] \tag{digits}\] [another text]. + text = re.sub( + r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$", + r"\[\1 \\tag{\2}\] \3", + text, + flags=re.M, + ) + # multi line + text = text.replace(r"\. ", ". ") + # bold formatting + text = text.replace(r"\bm{", r"\mathbf{").replace(r"{\\bm ", r"\mathbf{") + text = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", text) + # Reformat urls (http, ftp and https only) to markdown [url](url) clickable format + text = re.sub( + r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))", + r"[\1](\1)", + text, + ) + # algorithms + text = re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", text, flags=re.S) + + return text + + +def normalize_list_like_lines(generation): + """ + Normalize lines in the given text that resemble list items. The function looks for lines that start optionally with + '-' or '*', possibly followed by Roman numerals or digits indicating nesting levels. The function reformats such + lines to make them more structured. + + Args: + generation (str): The input text containing lines that need to be normalized. + + Returns: + str: The input text with the list-like lines normalized. + + Note: + The function uses regular expressions to identify and reformat the list-like lines. The patterns capture + optional bullet points, nesting levels indicated by numerals, and the actual list item content. The + normalization adjusts the bullet point style and nesting levels based on the captured patterns. + """ + + lines = generation.split("\n") + output_lines = [] + for line_no, line in enumerate(lines): + match = re.search(r". ([-*]) ", line) + if not match or line[0] not in ("-", "*"): + output_lines.append(line) + continue # Doesn't fit the pattern we want, no changes + delim = match.group(1) + " " + splits = line.split(delim)[1:] + replacement = "" + delim1 = line[0] + " " + + for i, item in enumerate(splits): + level = 0 + potential_numeral, _, rest = item.strip().partition(" ") + if not rest: + continue + # Infer current nesting level based on detected numbering + if re.match(r"^[\dixv]+((?:\.[\dixv])?)+$", potential_numeral, flags=re.I | re.M): + level = potential_numeral.count(".") + + replacement += ( + ("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip() + ) + + if line_no == len(lines) - 1: # If this is the last line in the generation + replacement += "\n" # Add an empty line to the end of the generation + + output_lines.append(replacement) + + return "\n".join(output_lines) + + +def find_next_punctuation(text: str, start_idx=0): + """ + Find the index of the next punctuation mark. + + Args: + text (`str`): + String to examine + start_idx (`int`, *optional*) + Index where to start + """ + + for i in range(start_idx, len(text)): + if text[i] in [".", "?", "!", "\n"]: + return i + + return None + + +def truncate_repetitions(text: str, min_len: int = 30) -> str: + """ + Attempt to truncate repeating segments in the input string. + + This function looks for the longest repeating substring at the end of the input string and truncates it to appear + only once. To be considered for removal, repetitions need to be continuous. + + Args: + text (`str`): + The input raw prediction to be truncated. 
+        min_len (`int`, *optional*, defaults to 30):
+            The minimum length of the repeating segment.
+
+    Returns:
+        `str`: The input string with repeated segments truncated.
+    """
+    text_lower = text.lower()
+    text_length = len(text_lower)
+
+    if text_length < 2 * min_len:
+        return text
+
+    # try to find a length at which the tail is repeating
+    max_repetition_length = None
+    for repetition_length in range(min_len, int(text_length / 2)):
+        # check if there is a repetition at the end
+        same = True
+        for i in range(0, repetition_length):
+            if text_lower[text_length - repetition_length - i - 1] != text_lower[text_length - i - 1]:
+                same = False
+                break
+
+        if same:
+            max_repetition_length = repetition_length
+
+    if max_repetition_length is None:
+        return text
+
+    lcs = text_lower[-max_repetition_length:]
+
+    # remove all but the last repetition
+    substituted_text = text
+    substituted_text_lower = text_lower
+    while substituted_text_lower.endswith(lcs):
+        substituted_text = substituted_text[:-max_repetition_length]
+        substituted_text_lower = substituted_text_lower[:-max_repetition_length]
+
+    # this is the tail with the repetitions
+    repeating_tail = text_lower[len(substituted_text_lower) :]
+
+    # add until next punctuation and make sure last sentence is not repeating
+    substituted_text_lower_out = substituted_text_lower
+    while True:
+        sentence_end = find_next_punctuation(text_lower, len(substituted_text_lower_out))
+        sentence_start = find_next_punctuation(text_lower[::-1], len(substituted_text_lower_out))
+        if sentence_end and sentence_start:
+            sentence = text_lower[sentence_start:sentence_end]
+            substituted_text_lower_out = text_lower[: sentence_end + 1]
+            if sentence in repeating_tail:
+                break
+        else:
+            break
+
+    text_out = text[: len(substituted_text_lower_out)]
+
+    return text_out
+
+
+def remove_numbers(lines):
+    def _clean(s):
+        return re.sub(r"(?:[\d_]|\*\*)", "", s).strip()
+
+    if isinstance(lines, str):
+        return _clean(lines)
+    out = []
+    for l in lines:
+        out.append(_clean(l))
+    return out
+
+
+def get_slices(lines, clean_lines):
+    """
+    Get slices of text based on specific criteria within the lines.
+
+    This function identifies and returns slices of text from the input lines based on certain conditions.
+
+    These conditions were chosen by the Nougat authors:
+    - The slice is less than 200 characters long.
+    - The slice is more than 3 characters long.
+    - The slice does not start with "[MISSING_PAGE".
+    - The slice is either the same as the next slice or the ratio of the two in terms of Levenshtein distance is
+      greater than 0.9.
+
+    Args:
+        lines (`List[str]`):
+            The list of lines containing the text.
+        clean_lines (`List[str]`):
+            A cleaned version of the text (without numbers).
+
+    Returns:
+        `List[tuple]`: A list of tuples representing the start and end indices of text slices.
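+
+    Example (an illustrative sketch using made-up repeated reference lines; computing the ratio requires the
+    Levenshtein backend):
+
+    ```python
+    lines = ["* [1] A. Author. Some repeated reference."] * 20
+    slices = get_slices(lines, remove_numbers(lines))
+    # `slices` holds a single (start, end) pair covering the repeated block
+    ```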
+ """ + indices = np.zeros(len(lines)) + for i in range(len(lines) - 1): + j = i + 1 + while not clean_lines[j] and j < len(lines) - 1: + j += 1 + if ( + len(clean_lines[i]) < 200 + and len(clean_lines[i]) > 3 + and len(clean_lines[j]) < 200 + and len(clean_lines[j]) > 3 + and not clean_lines[i].startswith("[MISSING_PAGE") + and (clean_lines[i] == clean_lines[j] or ratio(clean_lines[i], clean_lines[j]) > 0.9) + ): + indices[i:j] = 1 + ids = np.where(indices)[0] + slices = [] + if len(ids) == 0: + return slices + j0 = 0 + for j, x in enumerate(np.diff(ids) > 3): + if x: + slices.append((ids[j0], ids[j] + 2)) + j0 = j + 1 + slices.append((ids[j0], ids[-1] + 2)) + return [sli for sli in slices if sli[1] - sli[0] > 15] + + +def remove_slice_from_lines(lines, clean_text, slice) -> str: + """ + Remove a slice of text from the lines based on specific criteria. + + This function identifies a slice of text within the lines and removes it based on certain conditions. + + Args: + lines (list of str): The list of lines containing the text. + clean_text (list of str): A cleaned version of the text (without numbers). + slice (tuple): A tuple representing the start and end indices of the slice to be removed. + + Returns: + str: The removed slice of text as a single string. + """ + base = clean_text[slice[0]] + section = list(slice) + check_start_flag = False + # backwards pass, at most 5 lines + for line_idx in range(max(0, slice[0] - 1), max(0, slice[0] - 5), -1): + if not lines[line_idx]: + continue + if lines[line_idx] == "## References": + section[0] = line_idx + break + elif ratio(base, remove_numbers(lines[line_idx])) < 0.9: + section[0] = line_idx + 1 + potential_ref = remove_numbers(lines[max(0, line_idx - 1)].partition("* [")[-1]) + if len(potential_ref) >= 0.75 * len(base) and ratio(base, potential_ref) < 0.9: + section[0] = line_idx + check_start_flag = True + break + # forward pass, at most 5 lines + for line_idx in range(min(len(lines), slice[1]), min(len(lines), slice[1] + 5)): + if ratio(base, remove_numbers(lines[line_idx])) < 0.9: + section[1] = line_idx + break + if len(lines) <= section[1]: + section[1] = len(lines) - 1 + to_delete = "\n".join(lines[section[0] : section[1] + 1]) + # cut off next page content + itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]]) + while True: + try: + (ia, a) = next(itera) + while a.isnumeric(): + (ia, a) = next(itera) + (ib, b) = next(iterb) + while b.isnumeric(): + (ib, b) = next(iterb) + if a != b: + break + except StopIteration: + break + if check_start_flag and "* [" in to_delete: + to_delete = "* [" + to_delete.partition("* [")[-1] + try: + delta = len(lines[section[1]]) - ib - 1 + if delta > 0: + to_delete = to_delete[:-delta] + except UnboundLocalError: + pass + + return to_delete.strip() + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class NougatTokenizerFast(PreTrainedTokenizerFast): + """ + Fast tokenizer for Nougat (backed by HuggingFace tokenizers library). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. This class mainly adds Nougat-specific + methods for postprocessing the generated text. + + Args: + vocab_file (`str`, *optional*): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that + contains the vocabulary necessary to instantiate a tokenizer. 
+        tokenizer_file (`str`, *optional*):
+            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+            contains everything needed to load the tokenizer.
+
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether to clean up spaces after decoding; the cleanup consists of removing potential artifacts like
+            extra spaces.
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = None
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        clean_up_tokenization_spaces=False,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        pad_token="<pad>",
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            **kwargs,
+        )
+        self.vocab_file = vocab_file
+
+    def remove_hallucinated_references(self, text: str) -> str:
+        """
+        Remove hallucinated or missing references from the text.
+
+        This function identifies and removes references that are marked as missing or hallucinated from the input
+        text.
+
+        Args:
+            text (`str`):
+                The input text containing references.
+
+        Returns:
+            `str`: The text with hallucinated references removed.
+        """
+        lines = text.split("\n")
+        if len(lines) == 0:
+            return ""
+        clean_lines = remove_numbers(lines)
+        slices = get_slices(lines, clean_lines)
+        to_delete = []
+        for slice in slices:
+            to_delete.append(remove_slice_from_lines(lines, clean_lines, slice))
+        for to_delete in reversed(to_delete):
+            text = text.replace(to_delete, "\n\n[MISSING_PAGE_POST]\n\n")
+        text = re.sub(
+            r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
+            "\n\n[MISSING_PAGE_POST\\1]",
+            text,
+        )
+        return text
+
+    def correct_tables(self, generation: str) -> str:
+        """
+        Takes a generated string and fixes tables/tabulars to make them match the markdown format needed.
+
+        Args:
+            generation (str): The generated text to be postprocessed.
+
+        Returns:
+            str: The postprocessed text.
+
+        Example:
+
+        ```python
+        correct_tables("\\begin{table} \\begin{tabular}{l l} & \\ \\end{tabular} \\end{table}")
+        "\\begin{table}\n\\begin{tabular}{l l} & \\ \\end{tabular}\n\\end{table}"
+        ```
+        """
+        # remove obvious wrong tables
+        for l in generation.split("\n"):
+            if l.count("\\begin{tabular}") > 15 or l.count("\\multicolumn") > 60 or l.count("&") > 400:
+                generation = generation.replace(l, "")
+        # whitespace corrections
+
+        generation = generation.replace("\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}")
+        generation = generation.replace("\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}")
+        generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")
+
+        generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.M)
+
+        # Remove left-aligned empty LaTeX tabular blocks.
+ generation = generation.replace(r"\begin{tabular}{l l} & \\ \end{tabular}", "") + # Remove tabulars with just 2 newline characters. + generation = generation.replace("\\begin{tabular}{}\n\n\\end{tabular}", "") + return generation + + def post_process_single(self, generation: str, fix_markdown: bool = True) -> str: + """ + Postprocess a single generated text. Regular expressions used here are taken directly from the Nougat article + authors. These expressions are commented for clarity and tested end-to-end in most cases. + + Args: + generation (str): The generated text to be postprocessed. + fix_markdown (bool, optional): Whether to perform Markdown formatting fixes. Default is True. + + Returns: + str: The postprocessed text. + """ + generation = re.sub( + r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation + ) # too long section titles probably are none + generation = generation.strip() + # Remove LaTeX left margin tag + generation = generation.replace("\n* [leftmargin=*]\n", "\n") + # Remove lines with markdown headings starting with #, with numerals, + # and possibly roman numerals with trailing spaces and newlines + generation = re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.M) + # most likely hallucinated titles + lines = generation.split("\n") + if lines[-1].startswith("#") and lines[-1].lstrip("#").startswith(" ") and len(lines) > 1: + logger.info("Likely hallucinated title at the end of the page: " + lines[-1]) + generation = "\n".join(lines[:-1]) + # obvious repetition detection + generation = truncate_repetitions(generation) + # Reference corrections + generation = self.remove_hallucinated_references(generation) + # Remove lines starting with asterisks and numbers like "*[1]" and followed by capital letters and periods (ie too long references) + generation = re.sub(r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.M) + # Remove empty brackets after a reference number in brackets. *[12][]ABC will become *[12]ABC + generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.M) + # Remove single characters before or after 2 new lines + generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation) + # pmc math artifact correction + generation = re.sub( + r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])", + r"\1\(\2_{\3}\)\4", + generation, + ) + generation = re.sub(r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation) + # footnote mistakes + generation = re.sub( + r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))", + r"\1 \2", + generation, + ) + # TODO Come up with footnote formatting inside a table + generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation) + # itemize post processing + generation = normalize_list_like_lines(generation) + + if generation.endswith((".", "}")): + generation += "\n\n" + if re.match(r"[A-Z0-9,;:]$", generation): + # add space in case it there is a comma or word ending + generation += " " + elif generation.startswith(("#", "**", "\\begin")): + generation = "\n\n" + generation + elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")): + generation = generation + "\n\n" + else: + try: + last_word = generation.split(" ")[-1] + if last_word in nltk.corpus.words.words(): + generation += " " + except LookupError: + # add space just in case. 
Will split words but better than concatenating them + generation += " " + + # table corrections + generation = self.correct_tables(generation) + # Remove optional, empty square brackets after begin{array} + generation = generation.replace("\\begin{array}[]{", "\\begin{array}{") + # Remove empty or malformed LaTeX tabular blocks with 2 or more columns specified, with spaces and ampersands. + generation = re.sub( + r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}", + "", + generation, + ) + # Remove lines containing "S.A.B." one or more times. Was included in Nougat's code. + generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation) + # Remove markdown-style headers that are incomplete or empty on multiple lines. + generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.M) + # Remove lines with just one period. + generation = re.sub(r"^\.\s*$", "", generation, flags=re.M) + # Replace instances of three or more newlines with just two newlines. + generation = re.sub(r"\n{3,}", "\n\n", generation) + if fix_markdown: + return markdown_compatible(generation) + else: + return generation + + def post_process_generation( + self, + generation: Union[str, List[str]], + fix_markdown: bool = True, + num_workers: Optional[int] = None, + ) -> Union[str, List[str]]: + """ + Postprocess a generated text or a list of generated texts. + + This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting. + + Postprocessing is quite slow so it is recommended to use multiprocessing to speed up the process. + + Args: + generation (Union[str, List[str]]): + The generated text or a list of generated texts. + fix_markdown (`bool`, *optional*, defaults to `True`): + Whether to perform Markdown formatting fixes. + num_workers (`int`, *optional*): + Optional number of workers to pass to leverage multiprocessing (postprocessing several texts in + parallel). + + Returns: + Union[str, List[str]]: The postprocessed text or list of postprocessed texts. + """ + requires_backends(self, ["nltk", "levenshtein"]) + + if isinstance(generation, list): + if num_workers is not None and isinstance(num_workers, int): + with Pool(num_workers) as p: + return p.map(partial(self.post_process_single, fix_markdown=fix_markdown), generation) + else: + return [self.post_process_single(s, fix_markdown=fix_markdown) for s in generation] + else: + return self.post_process_single(generation, fix_markdown=fix_markdown) + + +__all__ = ["NougatTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/nystromformer/__init__.py b/docs/transformers/src/transformers/models/nystromformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c32df6d8db909207e4867e119d2a7fca4a0bd907 --- /dev/null +++ b/docs/transformers/src/transformers/models/nystromformer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
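+# Public names are resolved lazily: on first attribute access, `_LazyModule` imports the
+# relevant submodule, keeping `import transformers` cheap.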
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_nystromformer import *
+    from .modeling_nystromformer import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py b/docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..96a48b99fda4cd7ff056ceeb441889653682bec1
--- /dev/null
+++ b/docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py
@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2022 UW-Madison and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Nystromformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class NystromformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`NystromformerModel`]. It is used to instantiate
+    a Nystromformer model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Nystromformer
+    [uw-madison/nystromformer-512](https://huggingface.co/uw-madison/nystromformer-512) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30000):
+            Vocabulary size of the Nystromformer model. Defines the number of different tokens that can be represented
+            by the `input_ids` passed when calling [`NystromformerModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimension of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 510):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`NystromformerModel`].
+        segment_means_seq_len (`int`, *optional*, defaults to 64):
+            Sequence length used in segment-means.
+        num_landmarks (`int`, *optional*, defaults to 64):
+            The number of landmark (or Nystrom) points to use in Nystrom approximation of the softmax self-attention
+            matrix.
+        conv_kernel_size (`int`, *optional*, defaults to 65):
+            The kernel size of the depthwise convolution used in Nystrom approximation.
+        inv_coeff_init_option (`bool`, *optional*, defaults to `False`):
+            Whether or not to use exact coefficient computation for the initial values for the iterative method of
+            calculating the Moore-Penrose inverse of a matrix.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+
+    Example:
+
+    ```python
+    >>> from transformers import NystromformerModel, NystromformerConfig
+
+    >>> # Initializing a Nystromformer uw-madison/nystromformer-512 style configuration
+    >>> configuration = NystromformerConfig()
+
+    >>> # Initializing a model from the uw-madison/nystromformer-512 style configuration
+    >>> model = NystromformerModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "nystromformer"
+
+    def __init__(
+        self,
+        vocab_size=30000,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu_new",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=510,
+        type_vocab_size=2,
+        segment_means_seq_len=64,
+        num_landmarks=64,
+        conv_kernel_size=65,
+        inv_coeff_init_option=False,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.type_vocab_size = type_vocab_size
+        self.segment_means_seq_len = segment_means_seq_len
+        self.num_landmarks = num_landmarks
+        self.conv_kernel_size = conv_kernel_size
+        self.inv_coeff_init_option = inv_coeff_init_option
+        self.layer_norm_eps = layer_norm_eps
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+__all__ = ["NystromformerConfig"]
diff --git a/docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6664a7d8ad019a2779b164ed4b8783ea3a6f7e12
--- /dev/null
+++ b/docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert Nystromformer checkpoints from the original repository."""
+
+import argparse
+
+import torch
+
+from transformers import NystromformerConfig, NystromformerForMaskedLM
+
+
+def rename_key(orig_key):
+    if "model" in orig_key:
+        orig_key = orig_key.replace("model.", "")
+    if "norm1" in orig_key:
+        orig_key = orig_key.replace("norm1", "attention.output.LayerNorm")
+    if "norm2" in orig_key:
+        orig_key = orig_key.replace("norm2", "output.LayerNorm")
+    if "norm" in orig_key:
+        orig_key = orig_key.replace("norm", "LayerNorm")
+    if "transformer" in orig_key:
+        layer_num = orig_key.split(".")[0].split("_")[-1]
+        orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}")
+    if "mha.attn" in orig_key:
+        orig_key = orig_key.replace("mha.attn", "attention.self")
+    if "mha" in orig_key:
+        orig_key = orig_key.replace("mha", "attention")
+    if "W_q" in orig_key:
+        orig_key = orig_key.replace("W_q", "self.query")
+    if "W_k" in orig_key:
+        orig_key = orig_key.replace("W_k", "self.key")
+    if "W_v" in orig_key:
+        orig_key = orig_key.replace("W_v", "self.value")
+    if "ff1" in orig_key:
+        orig_key = orig_key.replace("ff1", "intermediate.dense")
+    if "ff2" in orig_key:
+        orig_key = orig_key.replace("ff2", "output.dense")
+    if "ff" in orig_key:
+        orig_key = orig_key.replace("ff", "output.dense")
+    if "mlm_class" in orig_key:
+        orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder")
+    if "mlm" in orig_key:
+        orig_key = orig_key.replace("mlm", "cls.predictions.transform")
+    if "cls" not in orig_key:
+        orig_key = "nystromformer." + orig_key
+
+    return orig_key
+
+
+def convert_checkpoint_helper(config, orig_state_dict):
+    for key in orig_state_dict.copy().keys():
+        val = orig_state_dict.pop(key)
+
+        if ("pooler" in key) or ("sen_class" in key) or ("conv.bias" in key):
+            continue
+        else:
+            orig_state_dict[rename_key(key)] = val
+
+    orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"]
+    orig_state_dict["nystromformer.embeddings.position_ids"] = (
+        torch.arange(config.max_position_embeddings).expand((1, -1)) + 2
+    )
+
+    return orig_state_dict
+
+
+def convert_nystromformer_checkpoint(checkpoint_path, nystromformer_config_file, pytorch_dump_path):
+    orig_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model_state_dict"]
+    config = NystromformerConfig.from_json_file(nystromformer_config_file)
+    model = NystromformerForMaskedLM(config)
+
+    new_state_dict = convert_checkpoint_helper(config, orig_state_dict)
+
+    model.load_state_dict(new_state_dict)
+    model.eval()
+    model.save_pretrained(pytorch_dump_path)
+
+    print(f"Checkpoint successfully converted.
Model saved at {pytorch_dump_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--pytorch_model_path", default=None, type=str, required=True, help="Path to Nystromformer pytorch checkpoint." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The json file for Nystromformer model config.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_nystromformer_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path) diff --git a/docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py b/docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py new file mode 100644 index 0000000000000000000000000000000000000000..afa17d94e5db6be536d1260d007aa5964c19284d --- /dev/null +++ b/docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -0,0 +1,1124 @@ +# coding=utf-8 +# Copyright 2022 UW-Madison The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Nystromformer model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_nystromformer import NystromformerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "uw-madison/nystromformer-512" +_CONFIG_FOR_DOC = "NystromformerConfig" + + +class NystromformerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + 
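+        # Note the `+ 2` offset below: position indices start at 2, matching the
+        # `max_position_embeddings + 2` rows allocated for `position_embeddings` above and
+        # the same offset applied by the checkpoint conversion script.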
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False
+        )
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+            persistent=False,
+        )
+
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        # Default `token_type_ids` to the all-zero buffer registered in the constructor, which is what
+        # usually happens when they are auto-generated. The registered buffer helps users trace the
+        # model without passing `token_type_ids`, and solves issue #5664.
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class NystromformerSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.num_landmarks = config.num_landmarks
+        self.seq_len = config.segment_means_seq_len
+        self.conv_kernel_size = config.conv_kernel_size
+
+        if config.inv_coeff_init_option:
+            # Any value other than "original" selects the exact-coefficient initialization
+            # of Z_0 in `iterative_inv` below.
+            self.init_option = "exact"
+        else:
+            self.init_option = "original"
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+
+        if self.conv_kernel_size is not None:
+            self.conv = nn.Conv2d(
+                in_channels=self.num_attention_heads,
+                out_channels=self.num_attention_heads,
+                kernel_size=(self.conv_kernel_size, 1),
+                padding=(self.conv_kernel_size // 2, 0),
+                bias=False,
+                groups=self.num_attention_heads,
+            )
+
+    # Function to approximate Moore-Penrose inverse via the iterative method
+    def iterative_inv(self, mat, n_iter=6):
+        identity = torch.eye(mat.size(-1), device=mat.device)
+        key = mat
+
+        # The entries of key are
positive and ||key||_{\infty} = 1 due to softmax + if self.init_option == "original": + # This original implementation is more conservative to compute coefficient of Z_0. + value = 1 / torch.max(torch.sum(key, dim=-2)) * key.transpose(-1, -2) + else: + # This is the exact coefficient computation, 1 / ||key||_1, of initialization of Z_0, leading to faster convergence. + value = 1 / torch.max(torch.sum(key, dim=-2), dim=-1).values[:, :, None, None] * key.transpose(-1, -2) + + for _ in range(n_iter): + key_value = torch.matmul(key, value) + value = torch.matmul( + 0.25 * value, + 13 * identity + - torch.matmul(key_value, 15 * identity - torch.matmul(key_value, 7 * identity - key_value)), + ) + return value + + def transpose_for_scores(self, layer): + new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + layer = layer.view(*new_layer_shape) + return layer.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size)) + key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size)) + + if self.num_landmarks == self.seq_len: + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in NystromformerModel forward() function) + attention_scores = attention_scores + attention_mask + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + context_layer = torch.matmul(attention_probs, value_layer) + + else: + q_landmarks = query_layer.reshape( + -1, + self.num_attention_heads, + self.num_landmarks, + self.seq_len // self.num_landmarks, + self.attention_head_size, + ).mean(dim=-2) + k_landmarks = key_layer.reshape( + -1, + self.num_attention_heads, + self.num_landmarks, + self.seq_len // self.num_landmarks, + self.attention_head_size, + ).mean(dim=-2) + + kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1) + kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1) + + attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2)) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in NystromformerModel forward() function) + attention_scores = attention_scores + attention_mask + + kernel_3 = nn.functional.softmax(attention_scores, dim=-1) + attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2)) + new_value_layer = torch.matmul(kernel_3, value_layer) + context_layer = torch.matmul(attention_probs, new_value_layer) + + if self.conv_kernel_size is not None: + context_layer += self.conv(value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class NystromformerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = 
nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class NystromformerAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = NystromformerSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = NystromformerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + self_outputs = self.self(hidden_states, attention_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Nystromformer +class NystromformerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Nystromformer +class NystromformerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class NystromformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = NystromformerAttention(config) + self.add_cross_attention = config.add_cross_attention + self.intermediate = NystromformerIntermediate(config) + self.output = NystromformerOutput(config) + + def forward(self, hidden_states, 
attention_mask=None, output_attentions=False): + self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class NystromformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([NystromformerLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Nystromformer +class NystromformerPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Nystromformer +class NystromformerLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = NystromformerPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Nystromformer +class NystromformerOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = NystromformerLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class NystromformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = NystromformerConfig + base_model_prefix = "nystromformer" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +NYSTROMFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`NystromformerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +NYSTROMFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. 
+ + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Nyströmformer Model transformer outputting raw hidden-states without any specific head on top.", + NYSTROMFORMER_START_DOCSTRING, +) +class NystromformerModel(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = NystromformerEmbeddings(config) + self.encoder = NystromformerEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
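# For intuition (an illustrative sketch, not part of this method): for a 2D
# padding mask the helper below effectively computes
#
#     extended = attention_mask[:, None, None, :].to(dtype)   # (batch, 1, 1, seq)
#     extended = (1.0 - extended) * torch.finfo(dtype).min    # 0.0 keeps, large negative masks
#
# which then broadcasts against attention scores of shape (batch, heads, seq, seq).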
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING) +class NystromformerForMaskedLM(NystromformerPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + + self.nystromformer = NystromformerModel(config) + self.cls = NystromformerOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
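Example (an illustrative sketch; the checkpoint name is an assumption, substitute any Nyströmformer MLM checkpoint):

```python
>>> import torch
>>> from transformers import AutoTokenizer, NystromformerForMaskedLM

>>> tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
>>> model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")

>>> inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # highest-scoring token at the masked position
>>> mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
>>> predicted_token_id = logits[0, mask_index].argmax(dim=-1)
```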
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class NystromformerClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + Nyströmformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + NYSTROMFORMER_START_DOCSTRING, +) +class NystromformerForSequenceClassification(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.nystromformer = NystromformerModel(config) + self.classifier = NystromformerClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nyströmformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output + and a softmax) e.g. for RocStories/SWAG tasks. + """, + NYSTROMFORMER_START_DOCSTRING, +) +class NystromformerForMultipleChoice(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.nystromformer = NystromformerModel(config) + self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_state = outputs[0] # (bs * num_choices, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs * num_choices, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nyströmformer Model with a token classification head on top (a linear layer on top of the hidden-states output) + e.g. for Named-Entity-Recognition (NER) tasks. + """, + NYSTROMFORMER_START_DOCSTRING, +) +class NystromformerForTokenClassification(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.nystromformer = NystromformerModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
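Example (an illustrative sketch; the checkpoint name is an assumption):

```python
>>> import torch
>>> from transformers import AutoTokenizer, NystromformerForTokenClassification

>>> tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
>>> model = NystromformerForTokenClassification.from_pretrained("uw-madison/nystromformer-512", num_labels=5)

>>> inputs = tokenizer("Alice lives in Paris", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)
>>> predicted_token_classes = logits.argmax(dim=-1)
```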
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nyströmformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + NYSTROMFORMER_START_DOCSTRING, +) +class NystromformerForQuestionAnswering(NystromformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.nystromformer = NystromformerModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nystromformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "NystromformerForMaskedLM", + "NystromformerForMultipleChoice", + "NystromformerForQuestionAnswering", + "NystromformerForSequenceClassification", + "NystromformerForTokenClassification", + "NystromformerLayer", + "NystromformerModel", + "NystromformerPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/olmo/__init__.py b/docs/transformers/src/transformers/models/olmo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139af5473e41fabbe978ee4b9c956dfd9b6e4453 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_olmo import * + from .modeling_olmo import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/olmo/configuration_olmo.py b/docs/transformers/src/transformers/models/olmo/configuration_olmo.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad5de6152055a42cf0178a6ca5ea4f32a7802d0 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo/configuration_olmo.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OLMo model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class OlmoConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OlmoModel`]. It is used to instantiate an OLMo + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/OLMo-7B-hf](https://huggingface.co/allenai/OLMo-7B-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the OLMo model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OlmoModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + clip_qkv (`float`, *optional*): + If not `None`, elements of query, key and value attention states are clipped so that their + absolute value does not exceed this value. 
+ + ```python + >>> from transformers import OlmoModel, OlmoConfig + + >>> # Initializing a OLMo 7B style configuration + >>> configuration = OlmoConfig() + + >>> # Initializing a model from the OLMo 7B style configuration + >>> model = OlmoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "olmo" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + clip_qkv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.clip_qkv = clip_qkv + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
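For example (illustrative), `{"type": "dynamic", "factor": 2.0}` is accepted, while `{"type": "xpos", "factor": 2.0}` or `{"type": "linear", "factor": 1.0}` raise a `ValueError`: the type must be `"linear"` or `"dynamic"`, and the factor a Python float strictly greater than 1 (`2.0`, not `2`):

```python
>>> from transformers import OlmoConfig
>>> config = OlmoConfig(rope_scaling={"type": "dynamic", "factor": 2.0})  # passes validation
```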
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + +__all__ = ["OlmoConfig"] diff --git a/docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py b/docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a2ad80b01f18145072ebb64d978e6c5d534601 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py @@ -0,0 +1,248 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import shutil +from pathlib import Path + +import torch +import yaml +from tokenizers import Tokenizer + +from transformers import OlmoConfig, OlmoForCausalLM +from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast + + +""" +Sample usage: + +``` +python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \ + --input_dir /path/to/downloaded/olmo/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import OlmoForCausalLM, AutoTokenizer + +model = OlmoForCausalLM.from_pretrained("/output/path") +tokenizer = AutoTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
+""" + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + config_path = Path(input_base_path) / "config.yaml" + olmo_config = yaml.safe_load(config_path.read_text())["model"] + + n_layers = olmo_config["n_layers"] + n_heads = olmo_config["n_heads"] + dim = olmo_config["d_model"] + dims_per_head = dim // n_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + max_position_embeddings = olmo_config["max_sequence_length"] + + vocab_size = olmo_config.get("embedding_size", olmo_config["vocab_size"]) + + if olmo_config.get("n_kv_heads", None) is not None: + num_key_value_heads = olmo_config["n_kv_heads"] # for GQA / MQA + elif olmo_config["multi_query_attention"]: # compatibility with other checkpoints + num_key_value_heads = 1 + else: + num_key_value_heads = n_heads + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True) + + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + # Unsharded + # TODO: Layernorm stuff + # TODO: multi query attention + fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads] + q_proj_weight, k_proj_weight, v_proj_weight = torch.split( + loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0 + ) + up_proj_weight, gate_proj_weight = torch.chunk( + loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0 + ) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight, + f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, + f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, + } + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + + # Unsharded + # TODO: Deal with weight-tying + state_dict = { + "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "lm_head.weight": loaded["transformer.ff_out.weight"] + if "transformer.ff_out.weight" in loaded + else loaded["transformer.wte.weight"], + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, 
os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + if olmo_config.get("mlp_hidden_size", None) is not None: + intermediate_size = olmo_config["mlp_hidden_size"] // 2 + else: + intermediate_size = (dim * olmo_config["mlp_ratio"]) // 2 + + config = OlmoConfig( + vocab_size=vocab_size, + hidden_size=dim, + intermediate_size=intermediate_size, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + pad_token_id=olmo_config["pad_token_id"], + bos_token_id=None, + eos_token_id=olmo_config["eos_token_id"], + tie_word_embeddings=olmo_config["weight_tying"], + rope_theta=base, + clip_qkv=olmo_config.get("clip_qkv"), + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + if tokenizer_path is not None: + _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id) + + print("Loading the checkpoint in a OLMo model.") + model = OlmoForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True) + # Avoid saving this as part of the config. + del model.config._name_or_path + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def _write_tokenizer( + output_path: Path, config: OlmoConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True +) -> None: + print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.") + + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + + eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 + pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id + + if fix_eos_token_id and eos_token_id == 0: + # Fixing a bug in OLMo where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + eos_token_id = 50279 + + tokenizer = GPTNeoXTokenizerFast( + tokenizer_object=base_tokenizer, + eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), + pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False), + unk_token=None, + bos_token=None, + ) + + tokenizer.save_pretrained(output_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + required=True, + help="Location of OLMo weights, which contains config.yaml and model.pt.", + ) + parser.add_argument( + "--tokenizer_json_path", + default=None, + help="Location of OLMo tokenizer json file.", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--no_fix_eos_token_id", + action="store_false", + dest="fix_eos_token_id", + help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.", + ) + parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + # Different OLMo versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. 
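# Example invocation with the flags defined above (paths are hypothetical):
#
#   python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \
#       --input_dir /path/to/olmo --output_dir /path/to/olmo-hf \
#       --tokenizer_json_path /path/to/olmo/tokenizer.json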
+ args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + tokenizer_path=args.tokenizer_json_path, + fix_eos_token_id=args.fix_eos_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/src/transformers/models/olmo/modeling_olmo.py b/docs/transformers/src/transformers/models/olmo/modeling_olmo.py new file mode 100644 index 0000000000000000000000000000000000000000..6de200e66285d328fb3c571ca7b6a903c80bb609 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo/modeling_olmo.py @@ -0,0 +1,814 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo/modular_olmo.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_olmo import OlmoConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "OlmoConfig" + + +class OlmoLayerNorm(nn.Module): + """LayerNorm but with no learnable weight or bias.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.normalized_shape = (hidden_size,) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to( + orig_dtype + ) + + +class OlmoMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, 
k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class OlmoAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: OlmoConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = 
nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class OlmoDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: OlmoConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) + + self.mlp = OlmoMLP(config) + self.input_layernorm = OlmoLayerNorm(config.hidden_size) + self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class OlmoRotaryEmbedding(nn.Module): + def __init__(self, config: OlmoConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+OLMO_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`OlmoConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
+    OLMO_START_DOCSTRING,
+)
+class OlmoPreTrainedModel(PreTrainedModel):
+    config_class = OlmoConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["OlmoDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+    _supports_attention_backend = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+OLMO_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
+            but you can also pass a `BlockMask` object directly here.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
+    OLMO_START_DOCSTRING,
+)
+class OlmoModel(OlmoPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`OlmoDecoderLayer`] + + Args: + config: OlmoConfig + """ + + def __init__(self, config: OlmoConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = OlmoLayerNorm(config.hidden_size) + self.rotary_emb = OlmoRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
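+        # Concretely (a sketch of what `_ignore_causal_mask_sdpa` checks below): when the 2D
+        # mask is missing or all ones, plain causal decoding needs no explicit mask, so we
+        # return `None` and let SDPA apply causality itself via `is_causal=True`.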
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+
+        # When output_attentions is True, the sdpa implementation's forward method falls back to the eager one
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        sequence_length = input_tensor.shape[1]
+        if using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
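+            # "Inverted" means additive here: entries are 0.0 where attention is allowed and
+            # `torch.finfo(dtype).min` where it is masked, so the mask can be summed directly
+            # onto the attention scores (the `else` branch below builds exactly this layout).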
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = OlmoModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. 
This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OlmoForCausalLM
+
+        >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-7B-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B-hf")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/olmo/modular_olmo.py b/docs/transformers/src/transformers/models/olmo/modular_olmo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2af171557a659d9d5aace40927a724174a5fc3d
--- /dev/null
+++ b/docs/transformers/src/transformers/models/olmo/modular_olmo.py
@@ -0,0 +1,148 @@
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...cache_utils import Cache
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...utils import logging
+from ..llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaMLP,
+    LlamaModel,
+    LlamaPreTrainedModel,
+    LlamaRotaryEmbedding,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
+from .configuration_olmo import OlmoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class OlmoLayerNorm(nn.Module):
+    """LayerNorm but with no learnable weight or bias."""
+
+    def __init__(self, hidden_size: int) -> None:
+        super().__init__()
+        self.normalized_shape = (hidden_size,)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_dtype = hidden_states.dtype
+        return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
+            orig_dtype
+        )
+
+
+class OlmoMLP(LlamaMLP):
+    def
__init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class OlmoAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class OlmoDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: OlmoConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.input_layernorm = OlmoLayerNorm(config.hidden_size) + self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) + + +class OlmoRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class OlmoPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class OlmoModel(LlamaModel): + def __init__(self, config: OlmoConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = OlmoLayerNorm(config.hidden_size) + + +class OlmoForCausalLM(LlamaForCausalLM): + pass + + +__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/olmo2/__init__.py b/docs/transformers/src/transformers/models/olmo2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2161a4948b5e32f600af135c33330c2e2c353c7 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
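+# This __init__ follows the library's lazy-import pattern: under `TYPE_CHECKING` the
+# symbols are imported only for static analysis, while at runtime the module is
+# replaced by a `_LazyModule` that resolves attributes (e.g. `Olmo2Config`) on first
+# access.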
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_olmo2 import * + from .modeling_olmo2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py b/docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1f396e0f8414c991411b8c508896166d282228 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py @@ -0,0 +1,180 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from ...configuration_utils import PretrainedConfig + + +class Olmo2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. 
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 50279):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
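+
+    When `num_key_value_heads` is left unset, it defaults to `num_attention_heads`,
+    i.e. standard multi-head attention rather than grouped-query attention.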
+ + ```python + >>> from transformers import Olmo2Model, Olmo2Config + + >>> # Initializing a Olmo2 7B style configuration + >>> configuration = Olmo2Config() + + >>> # Initializing a model from the Olmo2 7B style configuration + >>> model = Olmo2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "olmo2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
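+
+        Expected format: `{"type": "linear" | "dynamic", "factor": float > 1.0}`;
+        anything else raises a `ValueError` below.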
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + +__all__ = ["Olmo2Config"] diff --git a/docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py b/docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8fb54ddb654899c63733f74a4e18aa66813983 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py @@ -0,0 +1,306 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import argparse +import gc +import json +import os +import shutil +from pathlib import Path +from typing import Any, Dict + +import torch +import yaml +from tokenizers import Tokenizer + +from transformers import Olmo2Config, Olmo2ForCausalLM +from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast + + +""" +Sample usage: + +``` +python src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py \ + --input_dir /path/to/downloaded/olmo2/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import Olmo2ForCausalLM, AutoTokenizer + +model = Olmo2ForCausalLM.from_pretrained("/output/path") +tokenizer = AutoTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
+""" + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model( + model_path, + input_base_path, + include_tokenizer=True, + tokenizer_path=None, + safe_serialization=True, + fix_eos_token_id=True, + tmp_cleanup=True, +): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + config_path = Path(input_base_path) / "config.yaml" + olmo2_config = yaml.safe_load(config_path.read_text())["model"] + + if not olmo2_config.get("attention_layer_norm", False): + raise RuntimeError("OLMo2 checkpoints must have attention layer norm") + if not olmo2_config.get("norm_after", False): + raise RuntimeError("OLMo2 checkpoints must set norm_after to True") + + n_layers = olmo2_config["n_layers"] + n_heads = olmo2_config["n_heads"] + dim = olmo2_config["d_model"] + dims_per_head = dim // n_heads + base = olmo2_config["rope_theta"] + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + max_position_embeddings = olmo2_config["max_sequence_length"] + + vocab_size = olmo2_config.get("embedding_size", olmo2_config["vocab_size"]) + + if olmo2_config.get("n_kv_heads", None) is not None: + num_key_value_heads = olmo2_config["n_kv_heads"] # for GQA / MQA + elif olmo2_config["multi_query_attention"]: # compatibility with other checkpoints + num_key_value_heads = 1 + else: + num_key_value_heads = n_heads + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + + # Not sharded + # (The sharded implementation would also work, but this is simpler.) 
+ loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True) + + param_count = 0 + index_dict: Dict[str, Any] = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + # Unsharded + # TODO: Layernorm stuff + # TODO: multi query attention + fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads] + q_proj_weight, k_proj_weight, v_proj_weight = torch.split( + loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0 + ) + up_proj_weight, gate_proj_weight = torch.chunk( + loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0 + ) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight, + f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, + f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"], + f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.attn_norm.weight" + ], + f"model.layers.{layer_i}.post_feedforward_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.ff_norm.weight" + ], + } + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + + # Unsharded + # TODO: Deal with weight-tying + state_dict = { + "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "model.norm.weight": loaded["transformer.ln_f.weight"], + "lm_head.weight": loaded["transformer.ff_out.weight"] + if "transformer.ff_out.weight" in loaded + else loaded["transformer.wte.weight"], + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + if olmo2_config.get("mlp_hidden_size", None) is not None: + intermediate_size = olmo2_config["mlp_hidden_size"] // 2 + else: + intermediate_size = (dim * olmo2_config["mlp_ratio"]) // 2 + + if fix_eos_token_id and olmo2_config["eos_token_id"] == 0: + # Fixing a bug in OLMo where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + olmo2_config["eos_token_id"] = 50279 + + config = Olmo2Config( + vocab_size=vocab_size, + hidden_size=dim, + intermediate_size=intermediate_size, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + pad_token_id=olmo2_config["pad_token_id"], + bos_token_id=None, + eos_token_id=olmo2_config["eos_token_id"], + 
tie_word_embeddings=olmo2_config["weight_tying"], + rms_norm_eps=olmo2_config["layer_norm_eps"], + rope_theta=base, + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + if include_tokenizer: + _write_tokenizer(model_path, config, input_base_path, tokenizer_path) + + print("Loading the checkpoint in a OLMo2 model.") + model = Olmo2ForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True) + # Avoid saving this as part of the config. + del model.config._name_or_path + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + if tmp_cleanup: + # Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes + # errors if using NFS. + shutil.rmtree(tmp_model_path) + + +def _write_tokenizer( + output_path: Path, + config: Olmo2Config, + checkpoint_dir: str, + input_tokenizer_path: Path | None, +) -> None: + print(f"Saving a {GPT2TokenizerFast.__name__} to {output_path}.") + + if input_tokenizer_path is not None: + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + else: + config_path = Path(checkpoint_dir) / "config.yaml" + tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"] + + # Initialize tokenizer and validate vocab size. + if Path(tokenizer_config["identifier"]).is_file(): + base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"]) + else: + base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"]) + + eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 + pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id + + tokenizer = GPT2TokenizerFast( + tokenizer_object=base_tokenizer, + eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), + pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False), + ) + + tokenizer.save_pretrained(output_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + required=True, + help="Location of OLMo2 weights, which contains config.yaml and model.pt.", + ) + parser.add_argument( + "--no_tokenizer", + action="store_false", + dest="include_tokenizer", + help="If set, do not convert OLMo tokenizer to HF tokenizer.", + ) + parser.add_argument( + "--tokenizer_json_path", + type=Path, + default=None, + help="Location of OLMo2 tokenizer json file. Defaults to what is set in the config file.", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--no_fix_eos_token_id", + action="store_false", + dest="fix_eos_token_id", + help="If set, does not change eos token id from 0 to 50279 if it is 0. 
Changing 0 to 50279 is a bug fix, so use this option with care.", + ) + parser.add_argument( + "--no_tmp_cleanup", + action="store_false", + dest="tmp_cleanup", + help="If passed, don't remove temp dir at end of HF conversion.", + ) + parser.add_argument( + "--no_safe_serialization", + action="store_false", + dest="safe_serialization", + help="Whether or not to save using `safetensors`.", + ) + args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + include_tokenizer=args.include_tokenizer, + tokenizer_path=args.tokenizer_json_path, + fix_eos_token_id=args.fix_eos_token_id, + tmp_cleanup=args.tmp_cleanup, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py b/docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py new file mode 100644 index 0000000000000000000000000000000000000000..31e6805cfde8cbb5adb90c79adec9694d9931e44 --- /dev/null +++ b/docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py @@ -0,0 +1,820 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_olmo2 import Olmo2Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "Olmo2Config" + + +@use_kernel_forward_from_hub("RMSNorm") +class Olmo2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Olmo2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def rotate_half(x): + """Rotates half the hidden dims of 
the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Olmo2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = 
config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Olmo2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Olmo2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Olmo2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Olmo2MLP(config) + self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Olmo2RotaryEmbedding(nn.Module): + def __init__(self, config: Olmo2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + 
self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+OLMO2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Olmo2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.",
+    OLMO2_START_DOCSTRING,
+)
+class Olmo2PreTrainedModel(PreTrainedModel):
+    config_class = Olmo2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Olmo2DecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+    _supports_attention_backend = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, Olmo2RMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+OLMO2_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
+            but you can also pass a `BlockMask` object directly here.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.",
+    OLMO2_START_DOCSTRING,
+)
+class Olmo2Model(Olmo2PreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`Olmo2DecoderLayer`] + + Args: + config: Olmo2Config + """ + + def __init__(self, config: Olmo2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Olmo2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
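+        # A `None` return below lets the SDPA path rely on `is_causal=True`; conceptually this is roughly
+        # equivalent to passing an explicit lower-triangular boolean mask, e.g.:
+        #     attn_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)
+        # but it allows PyTorch to dispatch to its fused causal kernels instead of materializing the mask.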
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
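+            # "Inverted form" here means additive: 0.0 where attention is allowed and a large negative value
+            # (on the order of `torch.finfo(dtype).min`) where it is masked, already shaped
+            # (batch_size, 1, query_length, key_value_length), so it can be added to the attention scores as-is.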
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Olmo2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. 
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
+
+        >>> model = Olmo2ForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")
+        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-1124-7B")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/olmo2/modular_olmo2.py b/docs/transformers/src/transformers/models/olmo2/modular_olmo2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c43263e9545e2e1bc9cd78372770be011d1771cd
--- /dev/null
+++ b/docs/transformers/src/transformers/models/olmo2/modular_olmo2.py
@@ -0,0 +1,320 @@
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...utils import logging
+from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
+from ..olmo.configuration_olmo import OlmoConfig
+from ..olmo.modeling_olmo import (
+    OlmoAttention,
+    OlmoDecoderLayer,
+    OlmoForCausalLM,
+    OlmoModel,
+    OlmoRotaryEmbedding,
+    apply_rotary_pos_emb,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class Olmo2Config(OlmoConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the [allenai/OLMo-2-1124-7B](https://huggingface.co/allenai/OLMo-2-1124-7B).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50304):
+            Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Olmo2Model`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 50279):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + + ```python + >>> from transformers import Olmo2Model, Olmo2Config + + >>> # Initializing a Olmo2 7B style configuration + >>> configuration = Olmo2Config() + + >>> # Initializing a model from the Olmo2 7B style configuration + >>> model = Olmo2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "olmo2" + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + use_cache=use_cache, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + **kwargs, + ) + + self.rms_norm_eps = rms_norm_eps + del self.clip_qkv + + +class Olmo2RMSNorm(LlamaRMSNorm): + pass + + +ALL_LAYERNORM_LAYERS.append(Olmo2RMSNorm) + + +# Olmo2 attention is identical to OLMo attention except: +# - Norm is applied to attention queries and keys. +# - No qkv clipping. 
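+#
+# Schematically (a sketch of the difference, not code taken from either model):
+#     OLMo:   q = q_proj(x).clamp(-clip_qkv, clip_qkv)  # when config.clip_qkv is set
+#     OLMo2:  q = q_norm(q_proj(x))                     # RMSNorm over num_heads * head_dim; same for k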
+class Olmo2Attention(OlmoAttention): + def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx=layer_idx) + self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +# The OLMo2 layers are identical to those of the OLMo model except: +# - RMSNorm is used instead of standard layer norm. +# - Norm is applied after attention/feedforward rather than before. 
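+#
+# In pseudo-code, an OLMo sublayer computes x = x + attn(input_layernorm(x)) (pre-norm), while the
+# OLMo2 layer below computes x = x + post_attention_layernorm(attn(x)), and the MLP block follows
+# the same norm-after pattern with post_feedforward_layernorm.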
+class Olmo2DecoderLayer(OlmoDecoderLayer):
+    def __init__(self, config: Olmo2Config, layer_idx: int):
+        super().__init__(config, layer_idx=layer_idx)
+        self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
+        del self.input_layernorm
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
+    pass
+
+
+class Olmo2PreTrainedModel(LlamaPreTrainedModel):
+    pass
+
+
+# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
+# standard layer norm for the output norm.
+class Olmo2Model(OlmoModel):
+    def __init__(self, config: Olmo2Config):
+        super().__init__(config)
+        self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers = nn.ModuleList(
+            [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+
+
+# The heads now only need to redefine the model inside to the correct `Olmo2Model`
+class Olmo2ForCausalLM(OlmoForCausalLM):
+    pass
+
+
+__all__ = [
+    "Olmo2Config",
+    "Olmo2ForCausalLM",
+    "Olmo2Model",
+    "Olmo2PreTrainedModel",  # noqa: F822
+]
diff --git a/docs/transformers/src/transformers/models/olmoe/__init__.py b/docs/transformers/src/transformers/models/olmoe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e18d77cbad11bf218b6cd304f703b4d7935eef
--- /dev/null
+++ b/docs/transformers/src/transformers/models/olmoe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
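+
+# This __init__ uses the library's lazy-import pattern: at runtime the module object is replaced by a
+# `_LazyModule` that only imports `configuration_olmoe` / `modeling_olmoe` on first attribute access.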
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_olmoe import * + from .modeling_olmoe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py b/docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..0f24d5523bafd0398e6985959ffdd223261cc5ca --- /dev/null +++ b/docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py @@ -0,0 +1,182 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OLMoE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class OlmoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OlmoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 16): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 50279):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        clip_qkv (`float`, *optional*):
+            If not `None`, elements of query, key and value attention states are clipped so that their
+            absolute value does not exceed this value.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 64):
+            Number of routed experts.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
+            The aux loss factor for the total loss.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the topk probabilities.
+ + ```python + >>> from transformers import OlmoeModel, OlmoeConfig + + >>> # Initializing a OLMoE 7B A1B style configuration + >>> configuration = OlmoeConfig() + + >>> # Initializing a model from the OLMoE 7B A1B style configuration + >>> model = OlmoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "olmoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=2048, + intermediate_size=2048, + num_hidden_layers=16, + num_attention_heads=16, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + clip_qkv=None, + num_experts_per_tok=8, + num_experts=64, + output_router_logits=False, + router_aux_loss_coef=0.01, + norm_topk_prob=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.clip_qkv = clip_qkv + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.norm_topk_prob = norm_topk_prob + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["OlmoeConfig"] diff --git a/docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py b/docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc5a49c7e5e47c0a4f14b2fe1fc316a6426679d --- /dev/null +++ b/docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py @@ -0,0 +1,281 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Example for running: +0. 
Copy checkpoints to local storage
+aws s3 cp --recursive s3://ai2-llm/checkpoints/OLMoE/olmoe-8x1b-newhp-newds-final-annealFrom1200000/step23842 /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842
+1. Unshard your OLMoE checkpoint using https://github.com/allenai/OLMo/blob/7d63fe09d23cf23714da5aa633a44a90180195da/scripts/unshard.py
+python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --model-only
+python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842 /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842-unsharded --model-only
+2. Convert to transformers
+rm -rf olmoe; mkdir olmoe; python /data/niklas/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py --input_dir /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842-unsharded --tokenizer_json_path /data/niklas/llm/checkpoints/olmoe-step1200000-unsharded/tokenizer.json --output_dir olmoe
+3. Load model via:
+```
+from transformers import OlmoeForCausalLM, AutoTokenizer
+import torch
+model = OlmoeForCausalLM.from_pretrained("../transformers/olmoe", torch_dtype=torch.bfloat16).cuda()
+# or, to load it in the default FP32 instead of BF16:
+model = OlmoeForCausalLM.from_pretrained("../transformers/olmoe").cuda()
+tokenizer = AutoTokenizer.from_pretrained("../transformers/olmoe")
+inputs = tokenizer("Bitcoin is", return_tensors="pt")
+inputs = {k: v.cuda() for k, v in inputs.items()}
+out = model.generate(**inputs, max_length=64)
+print(tokenizer.decode(out[0]))
+# > # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical
+# Or quick sanity check:
+o = model(torch.tensor([[0, 1]]).cuda())
+# If the checkpoint is not converted to BF16 but kept in FP32:
+# > # Bitcoin is a digital currency that is not controlled by any central authority. It is a peer-to-peer payment system that allows users to send and receive payments from anywhere in the world. Bitcoin is also known as a cryptocurrency because it uses cryptography to secure transactions and prevent fraud.
+```
+
+Note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
+come in several checkpoints, they each contain a part of each weight of the model, so we need to load them all in RAM).
+
+Compare with OLMo codebase:
+```
+from olmo.model import OLMo
+import torch
+model = OLMo.from_checkpoint("/data/niklas/llm/checkpoints/olmoe-step1200000-unsharded-pt")
+model = model.cuda()
+model = model.to(torch.bfloat16)
+from transformers import AutoTokenizer
+tokenizer = AutoTokenizer.from_pretrained("../transformers/olmoe")
+inputs = tokenizer("Bitcoin is", return_tensors="pt")
+inputs = {k: v.cuda() for k, v in inputs.items()}
+out = model.generate(**inputs)
+print(tokenizer.decode(out[0][0][0]))
+# Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical problems.
It’s the first example of a growing category of money +# Or quick sanity check: +o = model(torch.tensor([[0, 1]]).cuda()) +``` +""" + +import argparse +import gc +import json +import os +import shutil +from pathlib import Path + +import torch +import yaml +from tokenizers import Tokenizer + +from transformers import OlmoeConfig, OlmoeForCausalLM +from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + config_path = Path(input_base_path) / "config.yaml" + olmoe_config = yaml.safe_load(config_path.read_text())["model"] + + if fix_eos_token_id: + olmoe_config["eos_token_id"] = 50279 + + n_layers = olmoe_config["n_layers"] + n_heads = olmoe_config["n_heads"] + dim = olmoe_config["d_model"] + dims_per_head = dim // n_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + max_position_embeddings = olmoe_config["max_sequence_length"] + + vocab_size = olmoe_config.get("embedding_size", olmoe_config["vocab_size"]) + + if olmoe_config.get("n_kv_heads", None) is not None: + num_key_value_heads = olmoe_config["n_kv_heads"] # for GQA / MQA + elif olmoe_config["multi_query_attention"]: # compatibility with other checkpoints + num_key_value_heads = 1 + else: + num_key_value_heads = n_heads + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + + # Not sharded + loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True) + + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads] + q_proj_weight, k_proj_weight, v_proj_weight = torch.split( + loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0 + ) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight, + f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, + f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"], + f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"], + f"model.layers.{layer_i}.mlp.gate.weight": loaded[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"], + f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.blocks.{layer_i}.attn_norm.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.ff_norm.weight" + ], + } + + num_experts = loaded[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"].shape[0] + dim_per_expert = loaded[f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"].shape[0] // 
num_experts + for expert_i in range(num_experts): + state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.gate_proj.weight"] = loaded[ + f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1" + ][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :] + state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.up_proj.weight"] = loaded[ + f"transformer.blocks.{layer_i}.ffn.experts.mlp.v1" + ][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :] + state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.down_proj.weight"] = loaded[ + f"transformer.blocks.{layer_i}.ffn.experts.mlp.w2" + ][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :].T.contiguous() + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + + # Unsharded + state_dict = { + "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "lm_head.weight": loaded["transformer.ff_out.weight"], + "model.norm.weight": loaded["transformer.ln_f.weight"], + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + config = OlmoeConfig( + vocab_size=vocab_size, + hidden_size=dim, + intermediate_size=dim_per_expert, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + pad_token_id=olmoe_config["pad_token_id"], + bos_token_id=None, + eos_token_id=olmoe_config["eos_token_id"], + tie_word_embeddings=olmoe_config["weight_tying"], + rope_theta=base, + clip_qkv=olmoe_config.get("clip_qkv"), + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + if tokenizer_path is not None: + _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id) + + print("Loading the checkpoint in a OLMoE model.") + model = OlmoeForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16) + # Avoid saving this as part of the config. 
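+    # (`_name_or_path` would otherwise record the temporary directory the weights were just loaded from.)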
+ del model.config._name_or_path + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def _write_tokenizer( + output_path: Path, config: OlmoeConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True +) -> None: + print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.") + + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + + eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 + pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id + + if fix_eos_token_id and eos_token_id == 0: + # Fixing a bug in OLMo where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + eos_token_id = 50279 + + tokenizer = GPTNeoXTokenizerFast( + tokenizer_object=base_tokenizer, + eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), + pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False), + unk_token=None, + bos_token=None, + ) + + tokenizer.save_pretrained(output_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + required=True, + help="Location of OLMoE weights, which contains config.yaml and model.pt.", + ) + parser.add_argument( + "--tokenizer_json_path", + default=None, + help="Location of OLMoE tokenizer json file.", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--no_fix_eos_token_id", + action="store_false", + dest="fix_eos_token_id", + help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.", + ) + parser.add_argument( + "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." + ) + args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + tokenizer_path=args.tokenizer_json_path, + fix_eos_token_id=args.fix_eos_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py b/docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..429ea28413a4cb18c47d4badb29558ee0b6e5efb --- /dev/null +++ b/docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py @@ -0,0 +1,1273 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OLMoE model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_olmoe import OlmoeConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "OlmoeConfig" + + +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class OlmoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + OlmoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe +class OlmoeRotaryEmbedding(nn.Module): + def __init__(self, config: OlmoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, 
persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe +class OlmoeMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class OlmoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.k_norm = OlmoeRMSNorm( + (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + 
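+        # `cos`/`sin` come in as (batch, seq_len, head_dim); `apply_rotary_pos_emb`
+        # unsqueezes them on dim 1 so they broadcast against the
+        # (batch, num_heads, seq_len, head_dim) query/key tensors.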
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class OlmoeFlashAttention2(OlmoeAttention): + """ + OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
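+        # i.e. the flag below is True exactly when the installed flash-attn still
+        # produces the top-left aligned mask described above and needs compensating.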
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_norm(self.q_proj(hidden_states))
+        key_states = self.k_norm(self.k_proj(hidden_states))
+        value_states = self.v_proj(hidden_states)
+        if self.config.clip_qkv is not None:
+            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim, but RoPE and the KV cache
+        # expect batch_size x num_heads x seq_length x head_dim, hence the transpose
+        # here and the transpose back further below.
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability,
+        # so the input hidden states get silently cast to float32. Hence, we cast them
+        # back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast
+        # the LayerNorms to fp32. (OlmoeRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32, this might be related to"
+                f" the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class OlmoeSdpaAttention(OlmoeAttention): + """ + OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `OlmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from OlmoeAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
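+        # Concretely: during prefill without an explicit mask (causal_mask is None,
+        # q_len > 1) SDPA builds the causal mask itself; during single-token decoding
+        # (q_len == 1) every cached position may be attended to, so no mask is needed.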
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+OLMOE_ATTENTION_CLASSES = {
+    "eager": OlmoeAttention,
+    "flash_attention_2": OlmoeFlashAttention2,
+    "sdpa": OlmoeSdpaAttention,
+}
+
+
+class OlmoeSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
+        self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be selected
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
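+            # (`torch.where` over `expert_mask[expert_idx]` of shape (top_k, num_tokens)
+            # returns `idx` = which routing slot picked this expert and `top_x` = which
+            # flattened token did, so `routing_weights[top_x, idx]` is each token's weight.)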
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class OlmoeDecoderLayer(nn.Module): + def __init__(self, config: OlmoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = OlmoeSparseMoeBlock(config) + self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +OLMOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OlmoeConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Olmoe Model outputting raw hidden-states without any specific head on top.", + OLMOE_START_DOCSTRING, +) +class OlmoePreTrainedModel(PreTrainedModel): + config_class = OlmoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OlmoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, OlmoeRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OLMOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Olmoe Model outputting raw hidden-states without any specific head on top.", + OLMOE_START_DOCSTRING, +) +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe +class OlmoeModel(OlmoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`] + + Args: + config: OlmoeConfig + """ + + def __init__(self, config: OlmoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = OlmoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING) + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask 
is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
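+
+        Example (a hypothetical sketch of the call, shown only to illustrate the
+        resulting shape; not part of the documented API):
+
+        ```python
+        >>> import torch
+        >>> mask_2d = torch.tensor([[0, 1, 1]])  # one left-padded sequence of length 3
+        >>> mask_4d = OlmoeModel._prepare_4d_causal_attention_mask_with_cache_position(
+        ...     mask_2d, sequence_length=3, target_length=3, dtype=torch.float32,
+        ...     device=torch.device("cpu"), cache_position=torch.arange(3), batch_size=1,
+        ... )
+        >>> mask_4d.shape  # (batch_size, 1, query_length, key_value_length)
+        torch.Size([1, 1, 3, 3])
+        ```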
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OlmoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OlmoeForCausalLM + + >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/omdet_turbo/__init__.py b/docs/transformers/src/transformers/models/omdet_turbo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b49479b2599b9cc689f9ef281b61fd32ae01c9 --- /dev/null +++ 
b/docs/transformers/src/transformers/models/omdet_turbo/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_omdet_turbo import * + from .modeling_omdet_turbo import * + from .processing_omdet_turbo import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py b/docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py new file mode 100644 index 0000000000000000000000000000000000000000..f26f01dc243a54a8f06e8cbae99e5ed6bcb7717e --- /dev/null +++ b/docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py @@ -0,0 +1,293 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OmDet-Turbo model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class OmDetTurboConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`]. + It is used to instantiate a OmDet-Turbo model according to the specified arguments, defining the model architecture + Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo + [omlab/omdet-turbo-swin-tiny-hf](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`PretrainedConfig`, *optional*): + The configuration of the text backbone. + backbone_config (`PretrainedConfig`, *optional*): + The configuration of the vision backbone. + use_timm_backbone (`bool`, *optional*, defaults to `True`): + Whether to use the timm for the vision backbone. 
+        backbone (`str`, *optional*, defaults to `"swin_tiny_patch4_window7_224"`):
+            The name of the pretrained vision backbone to use. If `use_pretrained_backbone=False`, a randomly
+            initialized backbone with the same architecture as `backbone` is used.
+        backbone_kwargs (`dict`, *optional*):
+            Additional kwargs for the vision backbone.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use a pretrained vision backbone.
+        apply_layernorm_after_vision_backbone (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization on the feature maps of the vision backbone output.
+        image_size (`int`, *optional*, defaults to 640):
+            The size (resolution) of each image.
+        disable_custom_kernels (`bool`, *optional*, defaults to `False`):
+            Whether to disable custom kernels.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value for layer normalization.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value for batch normalization.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        text_projection_in_dim (`int`, *optional*, defaults to 512):
+            The input dimension for the text projection.
+        text_projection_out_dim (`int`, *optional*, defaults to 512):
+            The output dimension for the text projection.
+        task_encoder_hidden_dim (`int`, *optional*, defaults to 1024):
+            The feedforward dimension for the task encoder.
+        class_embed_dim (`int`, *optional*, defaults to 512):
+            The dimension of the class embeddings.
+        class_distance_type (`str`, *optional*, defaults to `"cosine"`):
+            The type of distance used to compare predicted classes to the projected class embeddings.
+            Can be `"cosine"` or `"dot"`.
+        num_queries (`int`, *optional*, defaults to 900):
+            The number of queries.
+        csp_activation (`str`, *optional*, defaults to `"silu"`):
+            The activation function of the Cross Stage Partial (CSP) networks of the encoder.
+        conv_norm_activation (`str`, *optional*, defaults to `"gelu"`):
+            The activation function of the ConvNormLayer layers of the encoder.
+        encoder_feedforward_activation (`str`, *optional*, defaults to `"relu"`):
+            The activation function for the feedforward network of the encoder.
+        encoder_feedforward_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout rate following the activation of the encoder feedforward network.
+        encoder_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout rate of the encoder multi-head attention module.
+        hidden_expansion (`int`, *optional*, defaults to 1):
+            The hidden expansion of the CSP networks in the encoder.
+        vision_features_channels (`tuple(int)`, *optional*, defaults to `[256, 256, 256]`):
+            The projected vision features channels used as inputs for the decoder.
+        encoder_hidden_dim (`int`, *optional*, defaults to 256):
+            The hidden dimension of the encoder.
+        encoder_in_channels (`List(int)`, *optional*, defaults to `[192, 384, 768]`):
+            The input channels for the encoder.
+        encoder_projection_indices (`List(int)`, *optional*, defaults to `[2]`):
+            The indices of the input features projected by each layer.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            The number of attention heads for the encoder.
+        encoder_dim_feedforward (`int`, *optional*, defaults to 2048):
+            The feedforward dimension for the encoder.
+        encoder_layers (`int`, *optional*, defaults to 1):
+            The number of layers in the encoder.
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The positional encoding temperature in the encoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels for the multi-scale deformable attention module of the decoder. + decoder_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the decoder. + decoder_num_heads (`int`, *optional*, defaults to 8): + The number of heads for the decoder. + decoder_num_layers (`int`, *optional*, defaults to 6): + The number of layers for the decoder. + decoder_activation (`str`, *optional*, defaults to `"relu"`): + The activation function for the decoder. + decoder_dim_feedforward (`int`, *optional*, defaults to 2048): + The feedforward dimension for the decoder. + decoder_num_points (`int`, *optional*, defaults to 4): + The number of points sampled in the decoder multi-scale deformable attention module. + decoder_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate for the decoder. + eval_size (`Tuple[int, int]`, *optional*): + Height and width used to compute the effective height and width of the position embeddings after taking + into account the stride (see RTDetr). + learn_initial_query (`bool`, *optional*, defaults to `False`): + Whether to learn the initial query. + cache_size (`int`, *optional*, defaults to 100): + The cache size for the classes and prompts caches. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model is used as an encoder-decoder model or not. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from the architecture. The values in kwargs will be saved as part of the configuration + and can be used to control the model outputs. + + Examples: + + ```python + >>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection + + >>> # Initializing an OmDet-Turbo omlab/omdet-turbo-swin-tiny-hf style configuration + >>> configuration = OmDetTurboConfig() + + >>> # Initializing a model (with random weights) from the omlab/omdet-turbo-swin-tiny-hf style configuration + >>> model = OmDetTurboForObjectDetection(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "omdet-turbo" + attribute_map = { + "encoder_hidden_dim": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + text_config=None, + backbone_config=None, + use_timm_backbone=True, + backbone="swin_tiny_patch4_window7_224", + backbone_kwargs=None, + use_pretrained_backbone=False, + apply_layernorm_after_vision_backbone=True, + image_size=640, + disable_custom_kernels=False, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + init_std=0.02, + text_projection_in_dim=512, + text_projection_out_dim=512, + task_encoder_hidden_dim=1024, + class_embed_dim=512, + class_distance_type="cosine", + num_queries=900, + csp_activation="silu", + conv_norm_activation="gelu", + encoder_feedforward_activation="relu", + encoder_feedforward_dropout=0.0, + encoder_dropout=0.0, + hidden_expansion=1, + vision_features_channels=[256, 256, 256], + encoder_hidden_dim=256, + encoder_in_channels=[192, 384, 768], + encoder_projection_indices=[2], + encoder_attention_heads=8, + encoder_dim_feedforward=2048, + encoder_layers=1, + positional_encoding_temperature=10000, + num_feature_levels=3, + decoder_hidden_dim=256, + decoder_num_heads=8, + decoder_num_layers=6, + decoder_activation="relu", + decoder_dim_feedforward=2048, + decoder_num_points=4, +
decoder_dropout=0.0, + eval_size=None, + learn_initial_query=False, + cache_size=100, + is_encoder_decoder=True, + **kwargs, + ): + if use_timm_backbone: + if backbone_config is None: + backbone_kwargs = { + "out_indices": [1, 2, 3], + "img_size": image_size, + "always_partition": True, + } + elif backbone_config is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `swin` vision config.") + backbone_config = CONFIG_MAPPING["swin"]( + window_size=7, + image_size=image_size, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + if text_config is None: + logger.info( + "`text_config` is `None`. Initializing the config with the default `clip_text_model` text config." + ) + text_config = CONFIG_MAPPING["clip_text_model"]() + elif isinstance(text_config, dict): + text_model_type = text_config.get("model_type") + text_config = CONFIG_MAPPING[text_model_type](**text_config) + + if class_distance_type not in ["cosine", "dot"]: + raise ValueError( + f"Invalid `class_distance_type`. It should be either `cosine` or `dot`, but got {class_distance_type}." + ) + + self.text_config = text_config + self.backbone_config = backbone_config + self.use_timm_backbone = use_timm_backbone + self.backbone = backbone + self.backbone_kwargs = backbone_kwargs + self.use_pretrained_backbone = use_pretrained_backbone + self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.init_std = init_std + self.text_projection_in_dim = text_projection_in_dim + self.text_projection_out_dim = text_projection_out_dim + self.task_encoder_hidden_dim = task_encoder_hidden_dim + self.class_embed_dim = class_embed_dim + self.class_distance_type = class_distance_type + self.num_queries = num_queries + self.csp_activation = csp_activation + self.conv_norm_activation = conv_norm_activation + self.encoder_feedforward_activation = encoder_feedforward_activation + self.encoder_feedforward_dropout = encoder_feedforward_dropout + self.encoder_dropout = encoder_dropout + self.hidden_expansion = hidden_expansion + self.vision_features_channels = vision_features_channels + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.encoder_projection_indices = encoder_projection_indices + self.encoder_attention_heads = encoder_attention_heads + self.encoder_dim_feedforward = encoder_dim_feedforward + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.num_feature_levels = num_feature_levels + self.decoder_hidden_dim = decoder_hidden_dim + self.decoder_num_heads = decoder_num_heads + self.decoder_num_layers = decoder_num_layers + self.decoder_activation = decoder_activation + self.decoder_dim_feedforward = decoder_dim_feedforward + self.decoder_num_points = decoder_num_points + self.decoder_dropout = decoder_dropout + self.eval_size = 
eval_size + self.learn_initial_query = learn_initial_query + self.cache_size = cache_size + self.is_encoder_decoder = is_encoder_decoder + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + +__all__ = ["OmDetTurboConfig"] diff --git a/docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py b/docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..2e515e983408b887a4827c6849a8a7f601361614 --- /dev/null +++ b/docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py @@ -0,0 +1,349 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OmDet-Turbo checkpoints from the original repository. + +URL: https://github.com/om-ai-lab/OmDet""" + +import argparse + +import requests +import torch +from PIL import Image + +from transformers import ( + CLIPTokenizer, + DetrImageProcessor, + OmDetTurboConfig, + OmDetTurboForObjectDetection, + OmDetTurboProcessor, +) + + +IMAGE_MEAN = [123.675, 116.28, 103.53] +IMAGE_STD = [58.395, 57.12, 57.375] + + +def get_omdet_turbo_config(model_name, use_timm_backbone): + if "tiny" in model_name: + window_size = 7 + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + image_size = 640 + else: + raise ValueError("Model not supported, only supports tiny variant.") + + config = OmDetTurboConfig( + backbone_window_size=window_size, + backbone_image_size=image_size, + backbone_embed_dim=embed_dim, + backbone_depths=depths, + backbone_num_heads=num_heads, + backbone_out_indices=(1, 2, 3), + text_config={"model_type": "clip_text_model"}, + use_timm_backbone=use_timm_backbone, + backbone="swin_tiny_patch4_window7_224" if use_timm_backbone else None, + apply_layernorm_after_vision_backbone=True if use_timm_backbone else False, + use_pretrained_backbone=False, + ) + + return config + + +def create_rename_keys_vision(state_dict, config): + rename_keys = [] + # fmt: off + ########################################## VISION BACKBONE - START + for layer_name in state_dict.keys(): + if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"): + if config.use_timm_backbone: + layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone") + layer_name_replace = layer_name_replace.replace(".layers.", ".layers_") + if "downsample" in layer_name: + # get layer number + layer_num = int(layer_name.split(".")[2]) + layer_name_replace = layer_name_replace.replace(f"{layer_num}.downsample", f"{layer_num+1}.downsample") + else: + layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone") + layer_name_replace = layer_name_replace.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + layer_name_replace = layer_name_replace.replace("patch_embed.norm", "embeddings.norm") + if layer_name.startswith("backbone.layers"): + layer_name_replace = 
layer_name_replace.replace("norm1", "layernorm_before") + layer_name_replace = layer_name_replace.replace("norm2", "layernorm_after") + layer_name_replace = layer_name_replace.replace("attn.proj", "attention.output.dense") + layer_name_replace = layer_name_replace.replace("mlp.fc1", "intermediate.dense") + layer_name_replace = layer_name_replace.replace("mlp.fc2", "output.dense") + layer_name_replace = layer_name_replace.replace(".layers.", ".encoder.layers.") + layer_name_replace = layer_name_replace.replace(".attn.", ".attention.self.") + elif layer_name.startswith("backbone.norm"): + layer_num = int(layer_name.split("norm")[1].split(".")[0]) + if config.use_timm_backbone: + layer_name_replace = layer_name.replace("backbone", "vision_backbone") + layer_name_replace = layer_name_replace.replace(f"norm{layer_num}", f"layer_norms.{layer_num-1}") + else: + layer_name_replace = layer_name.replace(f"backbone.norm{layer_num}", f"vision_backbone.vision_backbone.hidden_states_norms.stage{layer_num+1}") + else: + continue + rename_keys.append((layer_name, layer_name_replace)) + ########################################## VISION BACKBONE - END + + ########################################## ENCODER - START + for layer_name, params in state_dict.items(): + if "neck" in layer_name: + layer_name_replace = layer_name.replace("neck", "encoder") + layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers") + if "fpn_blocks" in layer_name or "pan_blocks" in layer_name or "lateral_convs" in layer_name or "downsample_convs" in layer_name: + layer_name_replace = layer_name_replace.replace(".m.", ".bottlenecks.") + layer_name_replace = layer_name_replace.replace(".cv", ".conv") + layer_name_replace = layer_name_replace.replace(".bn", ".norm") + if "encoder_layer" in layer_name: + layer_name_replace = layer_name_replace.replace("encoder_layer", "encoder.0.layers.0") + layer_name_replace = layer_name_replace.replace(".linear", ".fc") + layer_name_replace = layer_name_replace.replace("norm1", "self_attn_layer_norm") + layer_name_replace = layer_name_replace.replace("norm2", "final_layer_norm") + rename_keys.append((layer_name, layer_name_replace)) + ########################################## ENCODER - END + + ########################################## DECODER - START + for layer_name, params in state_dict.items(): + if layer_name.startswith("decoder"): + layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers") + layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers") + layer_name_replace = layer_name_replace.replace("query_pos_head", "query_position_head") + layer_name_replace = layer_name_replace.replace("enc_bbox_head", "encoder_bbox_head") + layer_name_replace = layer_name_replace.replace("enc_output", "encoder_vision_features") + layer_name_replace = layer_name_replace.replace("dec_score_head", "decoder_class_head") + layer_name_replace = layer_name_replace.replace("dec_bbox_head", "decoder_bbox_head") + layer_name_replace = layer_name_replace.replace("enc_score_head", "encoder_class_head") + rename_keys.append((layer_name, layer_name_replace)) + ########################################## DECODER - END + # fmt: on + return rename_keys + + +def create_rename_keys_language(state_dict): + rename_keys = [] + # fmt: off + for layer_name in state_dict.keys(): + if layer_name.startswith("language_backbone") and not layer_name.startswith("language_backbone.text_projection"): + layer_name_replace = 
layer_name.replace("language_backbone", "language_backbone.model.text_model") + layer_name_replace = layer_name_replace.replace("transformer.resblocks", "encoder.layers") + layer_name_replace = layer_name_replace.replace("token_embedding", "embeddings.token_embedding") + layer_name_replace = layer_name_replace.replace("positional_embedding", "embeddings.position_embedding.weight") + layer_name_replace = layer_name_replace.replace(".attn", ".self_attn") + layer_name_replace = layer_name_replace.replace(".mlp.c_fc", ".mlp.fc1") + layer_name_replace = layer_name_replace.replace(".mlp.c_proj", ".mlp.fc2") + layer_name_replace = layer_name_replace.replace("ln_final", "final_layer_norm") + layer_name_replace = layer_name_replace.replace(".ln_", ".layer_norm") + rename_keys.append((layer_name, layer_name_replace)) + # fmt: on + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v_vision(state_dict, config): + state_dict_keys = list(state_dict.keys()) + for layer_name_vision in state_dict_keys: + if layer_name_vision.startswith("vision_backbone") and "qkv" in layer_name_vision: + layer_num = int(layer_name_vision.split(".")[4]) + hidden_size = config.backbone_config.embed_dim * 2**layer_num + if "weight" in layer_name_vision: + in_proj_weight = state_dict.pop(layer_name_vision) + state_dict[layer_name_vision.replace("qkv.weight", "key.weight")] = in_proj_weight[:hidden_size, :] + state_dict[layer_name_vision.replace("qkv.weight", "query.weight")] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[layer_name_vision.replace("qkv.weight", "value.weight")] = in_proj_weight[-hidden_size:, :] + elif "bias" in layer_name_vision: + in_proj_bias = state_dict.pop(layer_name_vision) + state_dict[layer_name_vision.replace("qkv.bias", "key.bias")] = in_proj_bias[:hidden_size] + state_dict[layer_name_vision.replace("qkv.bias", "query.bias")] = in_proj_bias[ + hidden_size : hidden_size * 2 + ] + state_dict[layer_name_vision.replace("qkv.bias", "value.bias")] = in_proj_bias[-hidden_size:] + + +def read_in_q_k_v_text(state_dict, config): + state_dict_keys = list(state_dict.keys()) + hidden_size = config.text_config.projection_dim + for layer_name_text in state_dict_keys: + if layer_name_text.startswith("language_backbone") and "in_proj" in layer_name_text: + if "weight" in layer_name_text: + in_proj_weight = state_dict.pop(layer_name_text) + state_dict[layer_name_text.replace("in_proj_weight", "q_proj.weight")] = in_proj_weight[ + :hidden_size, : + ] + state_dict[layer_name_text.replace("in_proj_weight", "k_proj.weight")] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[layer_name_text.replace("in_proj_weight", "v_proj.weight")] = in_proj_weight[ + -hidden_size:, : + ] + elif "bias" in layer_name_text: + in_proj_bias = state_dict.pop(layer_name_text) + state_dict[layer_name_text.replace("in_proj_bias", "q_proj.bias")] = in_proj_bias[:hidden_size] + state_dict[layer_name_text.replace("in_proj_bias", "k_proj.bias")] = in_proj_bias[ + hidden_size : hidden_size * 2 + ] + state_dict[layer_name_text.replace("in_proj_bias", "v_proj.bias")] = in_proj_bias[-hidden_size:] + + +def read_in_q_k_v_encoder(state_dict, config): + embed_dim = config.encoder_hidden_dim + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = 
state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict["encoder.encoder.0.layers.0.self_attn.query.weight"] = in_proj_weight[:embed_dim, :] + state_dict["encoder.encoder.0.layers.0.self_attn.query.bias"] = in_proj_bias[:embed_dim] + state_dict["encoder.encoder.0.layers.0.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :] + state_dict["encoder.encoder.0.layers.0.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2] + state_dict["encoder.encoder.0.layers.0.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :] + state_dict["encoder.encoder.0.layers.0.self_attn.value.bias"] = in_proj_bias[-embed_dim:] + + +def read_in_q_k_v_decoder(state_dict, config): + for layer_num in range(config.decoder_num_layers): + embed_dim = config.decoder_hidden_dim + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"decoder.layers.{layer_num}.self_attn.query.weight"] = in_proj_weight[:embed_dim, :] + state_dict[f"decoder.layers.{layer_num}.self_attn.query.bias"] = in_proj_bias[:embed_dim] + state_dict[f"decoder.layers.{layer_num}.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :] + state_dict[f"decoder.layers.{layer_num}.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2] + state_dict[f"decoder.layers.{layer_num}.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :] + state_dict[f"decoder.layers.{layer_num}.self_attn.value.bias"] = in_proj_bias[-embed_dim:] + + +def run_test(model, processor): + # We will verify our results on an image of cute cats + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + + classes = ["cat", "remote"] + task = "Detect {}.".format(", ".join(classes)) + inputs = processor(image, text=classes, task=task, return_tensors="pt") + + # Running forward + with torch.no_grad(): + outputs = model(**inputs) + + predicted_slice = outputs[1][0, :3, :3] + print(predicted_slice) + expected_slice = torch.tensor([[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]) + + assert torch.allclose(predicted_slice, expected_slice, atol=1e-4) + print("Looks ok!") + + +@torch.no_grad() +def convert_omdet_turbo_checkpoint(args): + model_name = args.model_name + pytorch_dump_folder_path = args.pytorch_dump_folder_path + push_to_hub = args.push_to_hub + use_timm_backbone = args.use_timm_backbone + + checkpoint_mapping = { + "omdet-turbo-tiny": [ + "https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/OmDet-Turbo_tiny_SWIN_T.pth", + "https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/ViT-B-16.pt", + ], + } + # Define default OmDetTurbo configuration + config = get_omdet_turbo_config(model_name, use_timm_backbone) + + # Load original checkpoint + checkpoint_url = checkpoint_mapping[model_name] + original_state_dict_vision = torch.hub.load_state_dict_from_url(checkpoint_url[0], map_location="cpu")["model"] + original_state_dict_vision = {k.replace("module.", ""): v for k, v in original_state_dict_vision.items()} + + # Rename keys +
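All the `read_in_q_k_v_*` helpers above rely on the same convention: PyTorch's fused `in_proj_weight` stacks the query, key, and value projections along dim 0, so slicing into three equal blocks recovers them. A quick self-contained check of that slicing:

```python
import torch

embed_dim = 4
# a fused (3 * embed_dim, embed_dim) projection, as stored by nn.MultiheadAttention
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)

query = in_proj_weight[:embed_dim, :]
key = in_proj_weight[embed_dim : embed_dim * 2, :]
value = in_proj_weight[-embed_dim:, :]

# the three slices exactly tile the original matrix
assert torch.equal(torch.cat([query, key, value], dim=0), in_proj_weight)
```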
new_state_dict = original_state_dict_vision.copy() + rename_keys_vision = create_rename_keys_vision(new_state_dict, config) + + rename_keys_language = create_rename_keys_language(new_state_dict) + + for src, dest in rename_keys_vision: + rename_key(new_state_dict, src, dest) + + for src, dest in rename_keys_language: + rename_key(new_state_dict, src, dest) + + if not use_timm_backbone: + read_in_q_k_v_vision(new_state_dict, config) + read_in_q_k_v_text(new_state_dict, config) + read_in_q_k_v_encoder(new_state_dict, config) + read_in_q_k_v_decoder(new_state_dict, config) + # add "model" prefix to all keys + new_state_dict = {f"model.{k}": v for k, v in new_state_dict.items()} + + # Load HF model + model = OmDetTurboForObjectDetection(config) + model.eval() + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + image_processor = DetrImageProcessor( + size={"height": config.backbone_image_size, "width": config.backbone_image_size}, + do_rescale=False, + image_mean=IMAGE_MEAN, + image_std=IMAGE_STD, + do_pad=False, + ) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # end-to-end consistency test + run_test(model, processor) + + if pytorch_dump_folder_path is not None: + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"omlab/{model_name}") + processor.push_to_hub(f"omlab/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="omdet-turbo-tiny", + type=str, + choices=["omdet-turbo-tiny"], + help="Name of the OmDetTurbo model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + parser.add_argument( + "--use_timm_backbone", action="store_true", help="Whether or not to use timm backbone for vision backbone." + ) + + args = parser.parse_args() + convert_omdet_turbo_checkpoint(args) diff --git a/docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py new file mode 100644 index 0000000000000000000000000000000000000000..67143cff68c7daacd18bd0299231cef215ec23f1 --- /dev/null +++ b/docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -0,0 +1,1711 @@ +# coding=utf-8 +# Copyright 2024 Om Research Lab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
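Before the modeling code, a note on the conversion script above: its entry point takes a plain namespace, so it can also be driven programmatically by building the same `argparse.Namespace` the CLI parser would produce (the output path here is hypothetical):

```python
from argparse import Namespace

args = Namespace(
    model_name="omdet-turbo-tiny",
    pytorch_dump_folder_path="./omdet-turbo-tiny-converted",  # hypothetical output dir
    push_to_hub=False,
    use_timm_backbone=True,
)
# convert_omdet_turbo_checkpoint(args)  # downloads the original checkpoints, then converts
```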
+"""PyTorch OmDet-Turbo model.""" + +import math +import warnings +from collections import OrderedDict +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2CLS, ACT2FN +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from ...utils.backbone_utils import load_backbone +from ..auto import AutoModel +from .configuration_omdet_turbo import OmDetTurboConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "OmDetTurboConfig" + + +@dataclass +class OmDetTurboEncoderOutput(ModelOutput): + """ + Base class for outputs of the OmDetTurboHybridEncoder. + + Args: + last_hidden_state (`torch.FloatTensor`): + Last hidden states of the encoder. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + extracted_states (`Tuple[torch.FloatTensor]`): + The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + extracted_states: Tuple[torch.FloatTensor] = None + + +@dataclass +class OmDetTurboDecoderOutput(ModelOutput): + """ + Base class for outputs of the OmDetTurboDecoder. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder. + decoder_coords (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates of the objects. + decoder_classes (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`): + The predicted classes of the objects. + encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates of the objects from the encoder. + encoder_class_logits (`Tuple[torch.FloatTensor]` of shape `(batch_size, num_queries, num_classes)`): + The predicted class of the objects from the encoder. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The initial reference points. + intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`): + The intermediate reference points.
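These output classes inherit the generic `ModelOutput` behavior from `transformers`: fields left as `None` are dropped from the mapping view, and set fields can be read either by attribute or by index. A small sketch with a toy subclass (the class and field names here are illustrative):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class ToyOutput(ModelOutput):
    last_hidden_state: Optional[torch.FloatTensor] = None
    attentions: Optional[tuple] = None


out = ToyOutput(last_hidden_state=torch.ones(2, 3))
assert out[0] is out.last_hidden_state  # tuple-style access mirrors attribute access
assert "attentions" not in out  # None fields are dropped from the mapping view
```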
+ hidden_states (`Optional[Tuple[torch.FloatTensor]]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_coords: Optional[torch.FloatTensor] = None + decoder_classes: Optional[torch.FloatTensor] = None + encoder_coord_logits: Optional[torch.FloatTensor] = None + encoder_class_logits: Tuple[torch.FloatTensor] = None + init_reference_points: Optional[torch.FloatTensor] = None + intermediate_reference_points: Tuple[Tuple[torch.FloatTensor]] = None + + +@dataclass +class OmDetTurboObjectDetectionOutput(ModelOutput): + """ + Output type of [`OmDetTurboForObjectDetection`]. + + Args: + loss (`torch.FloatTensor`): + The loss value. + decoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates logits of the objects. + decoder_class_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`): + The predicted class of the objects. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The initial reference points. + intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`): + The intermediate reference points. + encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates of the objects from the encoder. + encoder_class_logits (`Tuple[torch.FloatTensor]`): + The predicted class of the objects from the encoder. + encoder_extracted_states (`torch.FloatTensor`): + The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder. + decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*): + Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. + encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`.
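The `decoder_coord_logits` above carry four coordinates per query. Assuming the usual DETR-family convention of normalized `(center_x, center_y, width, height)` boxes (an assumption here; the processor's own post-processing is authoritative), converting to absolute corner format would look like:

```python
import torch


def center_to_corners(boxes: torch.Tensor, image_height: int, image_width: int) -> torch.Tensor:
    # boxes: (..., 4) in normalized (cx, cy, w, h); returns absolute (x1, y1, x2, y2)
    center_x, center_y, width, height = boxes.unbind(-1)
    corners = torch.stack(
        [center_x - width / 2, center_y - height / 2, center_x + width / 2, center_y + height / 2], dim=-1
    )
    scale = corners.new_tensor([image_width, image_height, image_width, image_height])
    return corners * scale
```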
Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*): + Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. + classes_structure (`torch.LongTensor`, *optional*): + The number of queried classes for each image. + """ + + loss: Optional[torch.FloatTensor] = None + decoder_coord_logits: Optional[torch.FloatTensor] = None + decoder_class_logits: Optional[torch.FloatTensor] = None + init_reference_points: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + encoder_coord_logits: Optional[torch.FloatTensor] = None + encoder_class_logits: Tuple[torch.FloatTensor] = None + encoder_extracted_states: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + classes_structure: Optional[torch.LongTensor] = None + + +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: List[Tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes_list): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + 
.view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class OmDetTurboLRUCache: + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + self.current_load = 0 + + def has(self, key) -> bool: + return key in self.cache + + def get(self, key): + """ + Get the value of the key if the key exists in the cache, otherwise return None. + Move the key to the end of the cache to show that it was recently used. + """ + if key not in self.cache: + return None + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key, value) -> None: + """ + Add the key-value pair to the cache. + Move the key to the end of the cache to show that it was recently used. + If the cache is full, remove the first key (least recently used). + """ + if key not in self.cache: + self.current_load += 1 + if self.current_load > self.capacity: + self.cache.popitem(last=False) + self.current_load -= 1 + + self.cache[key] = value + self.cache.move_to_end(key) + + +class OmDetTurboLanguageBackbone(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.model = AutoModel.from_config(config.text_config) + self.text_projection = nn.Parameter(torch.zeros(config.text_projection_in_dim, config.text_projection_out_dim)) + + def forward(self, hidden_states, mask=None, encode_type="task"): + text_outputs = self.model(hidden_states) + pooled_output = text_outputs[0] + if encode_type == "task": + if mask is None: + raise ValueError("mask is required for task encoding") + max_len = (mask != 0).sum(1).max().item() + truncated_mask = mask[:, :max_len] + truncated_output = pooled_output[:, :max_len, :] + return truncated_output.transpose(0, 1), truncated_mask + elif encode_type == "class": + max_pooled_output = pooled_output[torch.arange(pooled_output.shape[0]), hidden_states.argmax(dim=-1)] + projected_output = max_pooled_output @ self.text_projection + return projected_output + else: + raise ValueError(f"encode_type {encode_type} is not supported") + + +class OmDetTurboVisionBackbone(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone + self.vision_backbone = load_backbone(config) + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels] + ) + + def forward(self, pixel_values): + outputs = self.vision_backbone(pixel_values).feature_maps + if self.apply_layernorm_after_vision_backbone: + outputs = [ + layer_norm(output).permute(0, 3, 1, 2).contiguous() + for layer_norm, output in zip(self.layer_norms, outputs) + ] + + return outputs + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo +class OmDetTurboMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. 
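Stepping back to `OmDetTurboLRUCache` defined above, a short usage sketch: with capacity 2, inserting a third entry evicts the least recently used one, and `get` refreshes recency (the values here are placeholder strings standing in for cached embeddings):

```python
cache = OmDetTurboLRUCache(capacity=2)
cache.put("person", "person_embedding")  # placeholder values
cache.put("car", "car_embedding")
cache.get("person")  # marks "person" as most recently used
cache.put("dog", "dog_embedding")  # capacity exceeded: evicts "car", the LRU entry
assert cache.has("person") and cache.has("dog") and not cache.has("car")
```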
+ """ + + def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int): + super().__init__() + + self.attn = MultiScaleDeformableAttention() + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in OmDetTurboMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + # Ignore copy + total_elements = sum([shape[0] * shape[1] for shape in spatial_shapes_list]) + if total_elements != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + num_coordinates = reference_points.shape[-1] + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * 
reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + + output = self.output_proj(output) + + return output, attention_weights + + +# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrConvNormLayer with RTDetr->OmDetTurbo +class OmDetTurboConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrRepVggBlock with RTDetr->OmDetTurbo, activation_function->csp_activation +class OmDetTurboRepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: OmDetTurboConfig): + super().__init__() + + activation = config.csp_activation + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrCSPRepLayer with RTDetr->OmDetTurbo, activation_function->csp_activation +class OmDetTurboCSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. 
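The RepVGG block above sums a 3x3 branch (padding 1) and a 1x1 branch (padding 0); both preserve spatial dimensions, which is what makes the elementwise sum well defined. A minimal shape check of that structure:

```python
import torch
from torch import nn

hidden = 8
conv3x3 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, bias=False)
conv1x1 = nn.Conv2d(hidden, hidden, kernel_size=1, stride=1, padding=0, bias=False)

x = torch.randn(1, hidden, 16, 16)
# both branches keep the (16, 16) spatial dims, so the sum is valid
y = conv3x3(x) + conv1x1(x)
print(y.shape)  # torch.Size([1, 8, 16, 16])
```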
+ """ + + def __init__(self, config: OmDetTurboConfig): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.csp_activation + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[OmDetTurboRepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = OmDetTurboConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1) + hidden_state_2 = self.conv2(hidden_state) + return self.conv3(hidden_state_1 + hidden_state_2) + + +class OmDetTurboMultiheadAttention(nn.Module): + """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`.""" + + def __init__(self, config, hidden_size, num_attention_heads, dropout): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"The hidden size ({hidden_size}) is not a multiple of the number of attention " + f"heads ({num_attention_heads})" + ) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(dropout) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + queries: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + query_layer = self.transpose_for_scores(self.query(queries)) + key_layer = self.transpose_for_scores(self.key(keys)) + value_layer = self.transpose_for_scores(self.value(values)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
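`OmDetTurboMultiheadAttention` above is standard scaled dot-product attention computed per head. The core computation for a single head, stripped of the linear projections, masking, and dropout:

```python
import math

import torch

batch, seq_len, head_size = 1, 5, 64
queries = torch.randn(batch, seq_len, head_size)
keys = torch.randn(batch, seq_len, head_size)
values = torch.randn(batch, seq_len, head_size)

scores = queries @ keys.transpose(-1, -2) / math.sqrt(head_size)  # (1, 5, 5)
weights = scores.softmax(dim=-1)  # each row sums to 1
context = weights @ values  # (1, 5, 64)
```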
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class OmDetTurboEncoderLayer(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.self_attn = OmDetTurboMultiheadAttention( + config, + hidden_size=config.encoder_hidden_dim, + num_attention_heads=config.num_attention_heads, + dropout=config.encoder_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.encoder_dropout) + self.activation_fn = ACT2FN[config.encoder_feedforward_activation] + self.encoder_feedforward_dropout = nn.Dropout(config.encoder_feedforward_dropout) + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_dim_feedforward) + self.fc2 = nn.Linear(config.encoder_dim_feedforward, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
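The `attention_mask` convention documented above is additive: padding positions carry a very large negative value, so after the softmax inside the attention their weight is effectively zero. A sketch of building such a mask from a boolean padding indicator:

```python
import torch

padding = torch.tensor([[False, False, True, True]])  # last two positions are padding
seq_len = padding.shape[1]
attention_mask = torch.zeros(1, 1, seq_len, seq_len).masked_fill(
    padding[:, None, None, :], torch.finfo(torch.float32).min
)
# adding this mask to raw attention scores suppresses padded keys after softmax
```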
+ """ + residual = hidden_states + query = key = self.with_pos_embed(hidden_states, position_embeddings) + + hidden_states = self.self_attn( + queries=query, + keys=key, + values=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states, attentions = hidden_states if output_attentions else (hidden_states[0], None) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.encoder_feedforward_dropout(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + if output_attentions: + return hidden_states, attentions + + return (hidden_states,) + + +class OmDetTurboEncoder(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + + self.layers = nn.ModuleList([OmDetTurboEncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward( + self, src, src_mask=None, pos_embed=None, output_attentions: bool = False + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + hidden_states = src + attention = () if output_attentions else None + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + if output_attentions: + attention = attention + (hidden_states[1],) + hidden_states = hidden_states[0] + + return hidden_states, attention + + +class OmDetTurboHybridEncoder(nn.Module): + """ + Encoder consisting of channel projection layers, a set of `OmDetTurboEncoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). 
More details in the paper: https://arxiv.org/abs/2304.08069 + + Args: + config: OmDetTurboConfig + """ + + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.config = config + self.in_channels = config.encoder_in_channels + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encoder_projection_indices = config.encoder_projection_indices + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + + self.channel_projection_layers = nn.ModuleList() + for in_channel in self.in_channels: + self.channel_projection_layers.append( + nn.Sequential( + nn.Conv2d(in_channel, self.encoder_hidden_dim, kernel_size=(1, 1), bias=False), + nn.BatchNorm2d(self.encoder_hidden_dim), + ) + ) + + # encoder transformer + self.encoder = nn.ModuleList([OmDetTurboEncoder(config) for _ in range(len(self.encoder_projection_indices))]) + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append( + OmDetTurboConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=1, + stride=1, + activation=config.conv_norm_activation, + ) + ) + self.fpn_blocks.append(OmDetTurboCSPRepLayer(config)) + + # bottom-up pan + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append( + OmDetTurboConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=3, + stride=2, + activation=config.conv_norm_activation, + ) + ) + self.pan_blocks.append(OmDetTurboCSPRepLayer(config)) + + @staticmethod + def build_2d_sincos_position_embedding( + width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ): + grid_w = torch.arange(int(width), dtype=dtype, device=device) + grid_h = torch.arange(int(height), dtype=dtype, device=device) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward( + self, + inputs_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layers) that is passed to the encoder. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
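`build_2d_sincos_position_embedding` above is a static method, so it can be exercised standalone (assuming the class is in scope): for a `width x height` grid it returns one embedding per position, with the four quarters of the channel dimension holding sin/cos over width and height:

```python
import torch

pos_embed = OmDetTurboHybridEncoder.build_2d_sincos_position_embedding(
    width=4, height=4, embed_dim=256, temperature=10000.0
)
print(pos_embed.shape)  # torch.Size([1, 16, 256]): one 256-dim embedding per grid position
```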
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeddings + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + # get projection features + projected_features = [self.channel_projection_layers[i](feature) for i, feature in enumerate(hidden_states)] + # encoder + for encoder_layer_index, feature_to_project_index in enumerate(self.encoder_projection_indices): + if output_hidden_states: + encoder_states = encoder_states + (projected_features[feature_to_project_index],) + height, width = projected_features[feature_to_project_index].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = projected_features[feature_to_project_index].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + device=src_flatten.device, + dtype=src_flatten.dtype, + ).to(src_flatten.device, src_flatten.dtype) + else: + pos_embed = None + layer_outputs = self.encoder[encoder_layer_index]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + projected_features[feature_to_project_index] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (projected_features[feature_to_project_index],) + + # Feature Pyramid Network (FPN) + fpn_feature_maps = [projected_features[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_high = fpn_feature_maps[0] + feat_low = projected_features[idx - 1] + feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) + fpn_feature_maps[0] = feat_high + upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest") + fps_map = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1)) + fpn_feature_maps.insert(0, fps_map) + + # Path Aggregation Network (PAN) + fpn_states = [fpn_feature_maps[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = fpn_states[-1] + feat_high = fpn_feature_maps[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + hidden_states = self.pan_blocks[idx]( + torch.concat([downsample_feat, feat_high.to(downsample_feat.device)], dim=1) + ) + fpn_states.append(hidden_states) + if not return_dict: + return (fpn_states[-1], encoder_states, all_attentions, fpn_states) + return OmDetTurboEncoderOutput( + last_hidden_state=fpn_states[-1], + hidden_states=encoder_states, + attentions=all_attentions, + extracted_states=fpn_states, + ) + + +class OmDetTurboMLPWithDropout(nn.Module): + def __init__(self, config): + super().__init__() + self.linear1 = nn.Linear(config.class_embed_dim, config.task_encoder_hidden_dim) + self.activation = ACT2FN[config.decoder_activation] + self.dropout = nn.Dropout(config.decoder_dropout) + self.linear2 = nn.Linear(config.task_encoder_hidden_dim, config.class_embed_dim) + + def forward(self, x): + return 
self.linear2(self.dropout(self.activation(self.linear1(x)))) + + +class OmDetTurboMLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + hidden_layers_dims = [hidden_dim] * (num_layers - 1) + layers_dims = [input_dim] + hidden_layers_dims + [output_dim] + self.layers = nn.ModuleList( + [nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(layers_dims[:-1], layers_dims[1:])] + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class OmDetTurboResidualLayer(nn.Module): + """ + A residual connection followed by a layer norm. + """ + + def __init__(self, config): + super().__init__() + self.norm1 = nn.LayerNorm(config.class_embed_dim, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.decoder_dropout) + + def forward(self, x, y): + return self.norm1(x + self.dropout(y)) + + +class OmDetTurboTaskEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.mlp = OmDetTurboMLPWithDropout(config) + self.res1 = OmDetTurboResidualLayer(config) + + def forward(self, x): + mlp_out = self.mlp(x) + x = self.res1(x, mlp_out) + return x + + +class OmDetTurboDeformableTransformerDecoderLayer(nn.Module): + """ + A single layer of the Deformable Transformer Decoder. + """ + + def __init__(self, config): + super().__init__() + # self attention + self.self_attn = OmDetTurboMultiheadAttention( + config, + hidden_size=config.decoder_hidden_dim, + num_attention_heads=config.decoder_num_heads, + dropout=config.decoder_dropout, + ) + self.dropout1 = nn.Dropout(config.decoder_dropout) + self.norm1 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps) + + # cross attention + self.cross_attn = OmDetTurboMultiscaleDeformableAttention( + config, num_heads=config.decoder_num_heads, n_points=config.decoder_num_points + ) + self.dropout2 = nn.Dropout(config.decoder_dropout) + self.norm2 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps) + + # feed forward network + self.linear1 = nn.Linear(config.decoder_hidden_dim, config.decoder_dim_feedforward) + self.act = ACT2FN[config.decoder_activation] + self.dropout3 = nn.Dropout(config.decoder_dropout) + self.linear2 = nn.Linear(config.decoder_dim_feedforward, config.decoder_hidden_dim) + self.dropout4 = nn.Dropout(config.decoder_dropout) + self.norm3 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps) + + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward( + self, + decoder_embeddings, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=None, + attention_mask=None, + padding_mask=None, + query_position=None, + output_attentions=None, + output_hidden_states=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states + + origin_embedding_len = decoder_embeddings.shape[1] + + # self attention + query = key = self.with_pos_embed(decoder_embeddings, query_position) + # combine task_features with query, key, value + task_features = task_features.transpose(0, 1) + query = 
torch.cat((query, task_features), dim=1) + key = torch.cat((key, task_features), dim=1) + decoder_embeddings = torch.cat((decoder_embeddings, task_features), dim=1) + + outputs = self.self_attn( + query, + key, + decoder_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + context, self_attention = outputs if output_attentions else (outputs[0], None) + decoder_embeddings = decoder_embeddings + self.dropout1(context) + decoder_embeddings = self.norm1(decoder_embeddings) + + task_features = decoder_embeddings[:, origin_embedding_len:, :].transpose(0, 1) + decoder_embeddings = decoder_embeddings[:, :origin_embedding_len, :] + + # cross attention + hidden_states = self.with_pos_embed(decoder_embeddings, query_position) + reference_points = reference_points.unsqueeze(2) + outputs, cross_attention = self.cross_attn( + hidden_states=hidden_states, + attention_mask=padding_mask, + encoder_hidden_states=vision_features, + reference_points=reference_points, + spatial_shapes=vision_shapes, + spatial_shapes_list=vision_shapes_list, + level_start_index=level_start_index, + ) + decoder_embeddings = decoder_embeddings + self.dropout2(outputs) + residual = self.norm2(decoder_embeddings) + + # feed forward network + decoder_embeddings = self.linear2(self.dropout3(self.act(self.linear1(residual)))) + decoder_embeddings = residual + self.dropout4(decoder_embeddings) + decoder_embeddings = self.norm3(decoder_embeddings) + + return ( + decoder_embeddings, + task_features, + self_attention if output_attentions else None, + cross_attention if output_attentions else None, + ) + + +class OmDetTurboPreTrainedModel(PreTrainedModel): + config_class = OmDetTurboConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module): + def linear_init_(module_to_init): + bound = 1 / math.sqrt(module_to_init.weight.shape[0]) + nn.init.uniform_(module_to_init.weight, -bound, bound) + if hasattr(module_to_init, "bias") and module_to_init.bias is not None: + nn.init.uniform_(module_to_init.bias, -bound, bound) + + if isinstance(module, OmDetTurboEncoderLayer): + linear_init_(module.fc1) + linear_init_(module.fc2) + elif isinstance(module, OmDetTurboDecoder): + nn.init.constant_(module.encoder_bbox_head.layers[-1].weight, 0.0) + nn.init.constant_(module.encoder_bbox_head.layers[-1].bias, 0.0) + for mlp in module.decoder_bbox_head: + nn.init.constant_(mlp.layers[-1].weight, 0.0) + nn.init.constant_(mlp.layers[-1].bias, 0.0) + linear_init_(module.encoder_vision_features[0]) + nn.init.xavier_uniform_(module.encoder_vision_features[0].weight) + if module.learn_initial_query: + nn.init.xavier_uniform_(module.tgt_embed.weight) + nn.init.xavier_uniform_(module.query_position_head.layers[0].weight) + nn.init.xavier_uniform_(module.query_position_head.layers[1].weight) + for layer in module.channel_projection_layers: + nn.init.xavier_uniform_(layer[0].weight) + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, OmDetTurboDecoder): + module.gradient_checkpointing = value + + @staticmethod + def _get_cache_key_at_index(input_ids, attention_mask, index): + input_ids = input_ids[index] + input_mask = attention_mask[index] + cache_key = tuple(input_ids[input_mask != 0].tolist()) + return cache_key + + def get_cached_class_embeddings(self, 
classes_input_ids, classes_attention_mask): + not_cached_index = [] + not_cached_classes = [] + total_embeddings = [] + for idx, _ in enumerate(classes_input_ids): + cache_key = self._get_cache_key_at_index(classes_input_ids, classes_attention_mask, idx) + if self.language_cache_class.has(cache_key): + total_embeddings.append(self.language_cache_class.get(cache_key)) + else: + total_embeddings.append(None) + not_cached_index.append(idx) + not_cached_classes.append(cache_key) + + if not_cached_classes: + not_cached_classes_ids = torch.stack([classes_input_ids[idx] for idx in not_cached_index]) + embeddings = self.language_backbone(not_cached_classes_ids, encode_type="class") + for idx, emb in enumerate(embeddings): + idx_to_put = not_cached_index[idx] + total_embeddings[idx_to_put] = emb + self.language_cache_class.put(not_cached_classes[idx], emb) + + total_class_embs = torch.stack(total_embeddings).to(self.device) + return total_class_embs + + def get_cached_task_embeddings(self, tasks_input_ids, tasks_attention_mask): + not_cached_index = [] + not_cached_tasks = [] + total_task_features = [] + total_task_masks = [] + for idx, _ in enumerate(tasks_input_ids): + cache_key = self._get_cache_key_at_index(tasks_input_ids, tasks_attention_mask, idx) + if self.language_cache_prompt.has(cache_key): + task_feature, task_mask = self.language_cache_prompt.get(cache_key) + total_task_features.append(task_feature) + total_task_masks.append(task_mask) + else: + total_task_features.append(None) + total_task_masks.append(None) + not_cached_index.append(idx) + not_cached_tasks.append(cache_key) + + if not_cached_tasks: + not_cached_index_ids = torch.stack([tasks_input_ids[idx] for idx in not_cached_index]) + not_cached_mask = torch.stack([tasks_attention_mask[idx] for idx in not_cached_index]) + embeddings, masks = self.language_backbone(not_cached_index_ids, mask=not_cached_mask, encode_type="task") + + for idx in range(embeddings.shape[1]): + emb = embeddings[:, [idx], :] + idx_to_put = not_cached_index[idx] + cur_mask = torch.unsqueeze(masks[idx], dim=0).to(self.device) + total_task_features[idx_to_put] = emb + total_task_masks[idx_to_put] = cur_mask + self.language_cache_prompt.put(not_cached_tasks[idx], (emb, cur_mask)) + + # pad before concat if needed + max_len = max([task.shape[0] for task in total_task_features]) + for idx, task in enumerate(total_task_features): + if task.shape[0] < max_len: + pad_size = max_len - task.shape[0] + total_task_features[idx] = F.pad(task, (0, 0, 0, 0, 0, pad_size)) + total_task_masks[idx] = F.pad(total_task_masks[idx], (0, pad_size)) + + total_task_features = torch.cat(total_task_features, dim=1).to(self.device) + total_task_masks = torch.cat(total_task_masks, dim=0).to(self.device) + + return total_task_features, total_task_masks + + def get_language_embedding( + self, + classes_input_ids, + classes_attention_mask, + tasks_input_ids, + tasks_attention_mask, + classes_structure, + ): + batched_classes_embeddings = self.get_cached_class_embeddings(classes_input_ids, classes_attention_mask) + # regroup class embeddings using saved structure + max_class_size = torch.max(classes_structure) + class_embeddings_regrouped = [] + start = 0 + for size in classes_structure: + pad_size = max_class_size - size + class_embeddings_regrouped.append( + F.pad(batched_classes_embeddings[start : start + size], (0, 0, 0, pad_size)).unsqueeze(1) + ) + start += size + class_embeddings = torch.cat(class_embeddings_regrouped, dim=1) + + task_embeddings, task_mask = 
self.get_cached_task_embeddings(tasks_input_ids, tasks_attention_mask) + + return class_embeddings, task_embeddings, task_mask + + +OMDET_TURBO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OmDetTurboConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OMDET_TURBO_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. + + Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for + details. + + classes_input_ids (`torch.LongTensor` of shape `(total_classes (>= batch_size), sequence_length)`): + Indices of input classes sequence tokens in the vocabulary of the language model. + Several classes can be provided for each tasks, thus the tokenized classes are flattened + and the structure of the classes is provided in the `classes_structure` argument. + + Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + classes_attention_mask (`torch.BoolTensor` of shape `(total_classes (>= batch_size), num_classes, sequence_length)`): + Attention mask for the classes. This is a binary mask that indicates which tokens should be attended to, + and which should not. + + tasks_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input tasks sequence tokens in the vocabulary of the language model. + + Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + tasks_attention_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Attention mask for the tasks. This is a binary mask that indicates which tokens should be attended to, + and which should not. + + classes_structure (torch.LongTensor of shape `(batch_size)`): + Structure of the classes. This tensor indicates the number of classes for each task. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
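A note on the caching above: `get_cached_class_embeddings` and `get_cached_task_embeddings` avoid re-encoding prompts the language backbone has already seen; each sequence is keyed by the tuple of its non-padded token ids (see `_get_cache_key_at_index`). `OmDetTurboLRUCache` itself is defined elsewhere in this file; a minimal stand-in with the same `has`/`get`/`put` surface, assuming standard LRU eviction, might look like:

```python
from collections import OrderedDict

class SimpleLRUCache:
    """Illustrative stand-in for OmDetTurboLRUCache (interface inferred from its call sites)."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._store = OrderedDict()

    def has(self, key) -> bool:
        return key in self._store

    def get(self, key):
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value) -> None:
        self._store[key] = value
        self._store.move_to_end(key)
        if len(self._store) > self.capacity:
            self._store.popitem(last=False)  # evict the least recently used entry

# Keys are tuples of non-padded token ids, mirroring _get_cache_key_at_index above.
cache = SimpleLRUCache(capacity=2)
cache.put((101, 4937, 102), "embedding for 'cat'")
print(cache.has((101, 4937, 102)))  # True
```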
+ + """ + + +def _cosine_similarity_scaled(a, b, logit_scale): + a = a / a.norm(dim=2, keepdim=True).clamp_min(1e-12) + b = b / b.norm(dim=1, keepdim=True).clamp_min(1e-12) + logit_scale = logit_scale.exp() + logits_per_image = logit_scale * torch.bmm(a, b) + return logits_per_image + + +def get_class_similarity(class_distance_type, cls_feature, class_proj): + logit_scale = torch.tensor(1 / 0.07).log() + if class_distance_type == "cosine": + class_logits = _cosine_similarity_scaled(cls_feature, class_proj, logit_scale) + elif class_distance_type == "dot": + class_logits = torch.bmm(cls_feature, class_proj) + else: + raise Exception("Unknown class_distance_type {}".format(class_distance_type)) + return class_logits + + +def _inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class OmDetTurboDecoder(OmDetTurboPreTrainedModel): + def __init__(self, config: OmDetTurboConfig): + self.config = config + super().__init__(config) + self.gradient_checkpointing = False + + hidden_dim = config.decoder_hidden_dim + self.num_queries = config.num_queries + self.class_distance_type = config.class_distance_type + self.learn_initial_query = config.learn_initial_query + + # backbone feature projection + self.channel_projection_layers = nn.ModuleList( + nn.Sequential(nn.Conv2d(x, hidden_dim, 1, bias=False), nn.BatchNorm2d(hidden_dim)) + for x in config.vision_features_channels + ) + self.task_encoder = OmDetTurboTaskEncoder(config) + if config.class_embed_dim != hidden_dim: + self.task_project = nn.Linear(config.class_embed_dim, hidden_dim) + + # Transformer module + self.layers = nn.ModuleList( + [OmDetTurboDeformableTransformerDecoderLayer(config) for _ in range(config.decoder_num_layers)] + ) + self.decoder_num_layers = config.decoder_num_layers + # decoder embedding + if self.learn_initial_query: + self.tgt_embed = nn.Embedding(self.num_queries, hidden_dim) + self.query_position_head = OmDetTurboMLP( + input_dim=4, hidden_dim=2 * hidden_dim, output_dim=hidden_dim, num_layers=2 + ) + + # encoder head + self.encoder_vision_features = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps) + ) + self.encoder_class_head = nn.Linear(config.class_embed_dim, hidden_dim) + self.encoder_bbox_head = OmDetTurboMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=4, num_layers=3) + + # decoder head + self.decoder_class_head = nn.ModuleList( + [nn.Linear(config.class_embed_dim, hidden_dim) for _ in range(config.decoder_num_layers)] + ) + self.decoder_bbox_head = nn.ModuleList( + [OmDetTurboMLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(config.decoder_num_layers)] + ) + + # Initialize weights and apply final processing + self.post_init() + + @lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + # We always generate anchors in float32 to preserve equivalence between + # dynamic and static anchor inference + # Ignore copy + if spatial_shapes is None: + raise ValueError("spatial_shapes must be provided") + + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, dtype=dtype, device=device), + torch.arange(end=width, dtype=dtype, device=device), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + valid_wh = torch.tensor([width, height], dtype=dtype, device=device) + grid_xy = (grid_xy.unsqueeze(0) 
+ 0.5) / valid_wh + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + def _get_encoder_input(self, vision_features): + # get projection features + vision_features = [self.channel_projection_layers[i](feat) for i, feat in enumerate(vision_features)] + # get encoder inputs + new_vision_features = [] + new_vision_shapes_list = [] + for feat in vision_features: + height, width = feat.shape[2:] + # [batch_size, channels, height, width] -> [batch_size, height*width, channels] + new_vision_features.append(feat.flatten(2).permute(0, 2, 1)) + # [num_feature_levels, 2] + new_vision_shapes_list.append((height, width)) + + # [batch_size, height*width, channels] + new_vision_features = torch.cat(new_vision_features, 1) + new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64, device=vision_features[0].device) + level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1])) + + return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index + + def _get_decoder_input( + self, vision_features, vision_shapes, class_features, denoise_embeddings=None, denoise_bboxes=None + ): + batch_size = len(vision_features) + # prepare input for decoder + anchors, valid_mask = self.generate_anchors( + vision_shapes, device=vision_features.device, dtype=vision_features.dtype + ) + predicted_class_features = self.encoder_vision_features( + torch.where( + valid_mask, + vision_features, + torch.tensor(0.0, dtype=vision_features.dtype, device=vision_features.device), + ) + ) + + original_class_projected = self.encoder_class_head(class_features).permute(1, 2, 0) + encoder_class_similarity = get_class_similarity( + self.class_distance_type, predicted_class_features, original_class_projected + ) + + # dynamic anchors + static content + # (batch_size, height*width, 4) + encoder_outputs_bboxes = self.encoder_bbox_head(predicted_class_features) + anchors + + # query selection + # (batch_size, num_queries) + topk_ind = torch.topk(encoder_class_similarity.max(-1).values, self.num_queries, dim=1).indices.view(-1) + # (batch_size, num_queries) + batch_ind = ( + torch.arange(end=batch_size, dtype=topk_ind.dtype, device=topk_ind.device) + .unsqueeze(-1) + .repeat(1, self.num_queries) + .view(-1) + ) + + reference_points = encoder_outputs_bboxes[batch_ind, topk_ind].view(batch_size, self.num_queries, -1) + encoder_bboxes = reference_points.sigmoid() + if denoise_bboxes is not None: + reference_points = torch.cat([denoise_bboxes, reference_points], 1) + if self.training: + reference_points = reference_points.detach() + encoder_class_similarity = encoder_class_similarity[batch_ind, topk_ind].view(batch_size, self.num_queries, -1) + + if self.learn_initial_query: + embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1) + else: + embeddings = predicted_class_features[batch_ind, topk_ind].view(batch_size, self.num_queries, -1) + if self.training: + embeddings = embeddings.detach() + if denoise_embeddings is not None: + embeddings = torch.cat([denoise_embeddings, embeddings], 1) + + return embeddings, 
reference_points, encoder_bboxes, encoder_class_similarity, anchors + + def forward( + self, + vision_features, + class_features, + task_features, + task_mask, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + Args: + vision_features (`torch.FloatTensor`): The sequence of vision features. shape depends on the vision + backbone. + class_features (`torch.FloatTensor`): The sequence of class features of shape + `(class_sequence_length, batch_size, class_embed_dim)`. + task_features (`torch.FloatTensor`): The sequence of task features of shape + `(task_sequence_length, batch_size, decoder_hidden_dim)`. + task_mask (`torch.LongTensor`): The mask for the task features of shape `(batch_size, task_sequence_length)`. + output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain + tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_features, vision_shapes, vision_shapes_list, level_start_index = self._get_encoder_input( + vision_features + ) + + # todo add denoising for training + denoise_embeddings, denoise_bboxes, key_padding_mask = None, None, None + batch_size = task_mask.shape[0] + + # compose attn_mask for vision_emb and task_emb fusion + task_features = self.task_encoder(task_features) + if self.task_project is not None: + task_features = self.task_project(task_features) + src_key_mask = (task_mask == 0).detach() + attn_mask_len = self.num_queries + fusion_size = attn_mask_len + task_features.shape[0] + key_padding_mask = torch.zeros([batch_size, fusion_size], dtype=torch.bool).to(task_features.device) + key_padding_mask[:, attn_mask_len:] = src_key_mask + attention_mask = _prepare_4d_attention_mask(~key_padding_mask, dtype=vision_features.dtype) + decoder_embeddings, reference_points, encoder_bboxes, encoder_class_similarity, init_reference_points = ( + self._get_decoder_input( + vision_features, tuple(vision_shapes_list), class_features, denoise_embeddings, denoise_bboxes + ) + ) + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if output_attentions else None + predicted_class_features = decoder_embeddings + + if output_hidden_states: + all_hidden_states = all_hidden_states + (predicted_class_features,) + decoder_bboxes = [] + decoder_classes = [] + last_refined_bbox = None + reference_points = reference_points.sigmoid() + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing and self.training: + predicted_class_features, task_features, self_attention, cross_attention = ( + self._gradient_checkpointing_func( + layer.__call__, + predicted_class_features, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=level_start_index, + attention_mask=attention_mask, + 
query_position=self.query_position_head(reference_points), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + ) + else: + predicted_class_features, task_features, self_attention, cross_attention = layer( + predicted_class_features, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=level_start_index, + attention_mask=attention_mask, + query_position=self.query_position_head(reference_points), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + if output_attentions: + all_self_attns = all_self_attns + (self_attention,) + all_cross_attns = all_cross_attns + (cross_attention,) + if output_hidden_states: + all_hidden_states = all_hidden_states + (predicted_class_features,) + + refined_bbox = torch.sigmoid( + self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(reference_points) + ) + original_class_projected = self.decoder_class_head[i](class_features).permute(1, 2, 0) + if self.training: + decoder_classes.append( + get_class_similarity( + class_distance_type=self.class_distance_type, + cls_feature=predicted_class_features, + class_proj=original_class_projected, + ) + ) + if i == 0: + decoder_bboxes.append(refined_bbox) + else: + decoder_bboxes.append( + torch.sigmoid( + self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(last_refined_bbox) + ) + ) + elif i == self.decoder_num_layers - 1: + decoder_classes.append( + get_class_similarity(self.class_distance_type, predicted_class_features, original_class_projected) + ) + decoder_bboxes.append(refined_bbox) + break + last_refined_bbox = refined_bbox + reference_points = refined_bbox.detach() if self.training else refined_bbox + if output_attentions: + all_attns += (all_self_attns, all_cross_attns) + + last_hidden_state = predicted_class_features + decoder_bboxes = torch.stack(decoder_bboxes) + decoder_classes = torch.stack(decoder_classes) + + if not return_dict: + return ( + last_hidden_state, + all_hidden_states, + all_attns, + decoder_bboxes, + decoder_classes, + encoder_bboxes, + encoder_class_similarity, + init_reference_points, + reference_points, + ) + + return OmDetTurboDecoderOutput( + last_hidden_state=last_hidden_state, + hidden_states=all_hidden_states, + attentions=all_attns, + decoder_coords=decoder_bboxes, + decoder_classes=decoder_classes, + encoder_coord_logits=encoder_bboxes, + encoder_class_logits=encoder_class_similarity, + init_reference_points=init_reference_points, + intermediate_reference_points=reference_points, + ) + + +@add_start_docstrings( + """ + OmDetTurbo Model (consisting of a vision and a text backbone, and encoder-decoder architecture) outputting + bounding boxes and classes scores for tasks such as COCO detection. 
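A note on the refinement step in the decoder loop above: each layer's bbox head predicts an offset in logit space, which is added to `_inverse_sigmoid` of the current reference points and squashed back through a sigmoid, so refined boxes always stay strictly inside (0, 1). A small numeric sketch (the offset values are made up):

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    # Same construction as _inverse_sigmoid above.
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

reference = torch.tensor([[0.50, 0.50, 0.20, 0.20]])  # normalized (center_x, center_y, w, h)
offset = torch.tensor([[0.40, -0.20, 0.00, 0.10]])    # stand-in for a bbox-head prediction
refined = torch.sigmoid(offset + inverse_sigmoid(reference))
print(refined)  # box nudged right and up, still strictly within (0, 1)
```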
+
+    """,
+    OMDET_TURBO_START_DOCSTRING,
+)
+class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
+    def __init__(self, config: OmDetTurboConfig):
+        super().__init__(config)
+        self.vision_backbone = OmDetTurboVisionBackbone(config)
+        self.language_backbone = OmDetTurboLanguageBackbone(config)
+        self.encoder = OmDetTurboHybridEncoder(config)
+        self.decoder = OmDetTurboDecoder(config)
+        self.num_queries = config.num_queries
+
+        self.language_cache_class = OmDetTurboLRUCache(config.cache_size)
+        self.language_cache_prompt = OmDetTurboLRUCache(config.cache_size)
+        self.vocab_size = config.text_config.vocab_size
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_backbone.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_backbone.model.set_input_embeddings(value)
+
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        model_embeds = self.language_backbone.model.resize_token_embeddings(
+            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
+        )
+        self.config.text_config.vocab_size = model_embeds.num_embeddings
+        self.vocab_size = model_embeds.num_embeddings
+        return model_embeds
+
+    @add_start_docstrings_to_model_forward(OMDET_TURBO_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        classes_input_ids: torch.LongTensor,
+        classes_attention_mask: torch.LongTensor,
+        tasks_input_ids: torch.LongTensor,
+        tasks_attention_mask: torch.LongTensor,
+        classes_structure: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> import requests
+        >>> from PIL import Image
+
+        >>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
+
+        >>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+        >>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> classes = ["cat", "remote"]
+        >>> task = "Detect {}.".format(", ".join(classes))
+        >>> inputs = processor(image, text=classes, task=task, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+
+        >>> # convert outputs (bounding boxes and class logits)
+        >>> results = processor.post_process_grounded_object_detection(
+        ...     outputs,
+        ...     classes=classes,
+        ...     target_sizes=[image.size[::-1]],
+        ...     score_threshold=0.3,
+        ...     nms_threshold=0.3,
+        ... )[0]
+        >>> for score, class_name, box in zip(results["scores"], results["classes"], results["boxes"]):
+        ...     box = [round(i, 1) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {class_name} with confidence "
+        ...         f"{round(score.item(), 2)} at location {box}"
+        ...     )
+        Detected remote with confidence 0.76 at location [39.9, 71.3, 176.5, 117.9]
+        Detected cat with confidence 0.72 at location [345.1, 22.5, 639.7, 371.9]
+        Detected cat with confidence 0.65 at location [12.7, 53.8, 315.5, 475.3]
+        Detected remote with confidence 0.57 at location [333.4, 75.6, 370.7, 187.0]
+        ```"""
+        if labels is not None:
+            raise NotImplementedError("Training is not implemented yet")
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        loss = None
+        image_features = self.vision_backbone(pixel_values)
+        encoder_outputs = self.encoder(
+            image_features,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        class_features, task_features, task_mask = self.get_language_embedding(
+            classes_input_ids,
+            classes_attention_mask,
+            tasks_input_ids,
+            tasks_attention_mask,
+            classes_structure,
+        )
+        encoder_extracted_states = encoder_outputs.extracted_states if return_dict else encoder_outputs[-1]
+        decoder_outputs = self.decoder(
+            encoder_extracted_states,
+            class_features,
+            task_features,
+            task_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return tuple(
+                output
+                for output in [
+                    loss,
+                    decoder_outputs[3][-1],
+                    decoder_outputs[4][-1],
+                    decoder_outputs[7],
+                    decoder_outputs[8],
+                    decoder_outputs[5],
+                    decoder_outputs[6],
+                    encoder_outputs[-1],
+                    decoder_outputs[1],
+                    decoder_outputs[2],
+                    encoder_outputs[1],
+                    encoder_outputs[2],
+                    classes_structure,
+                ]
+                if output is not None
+            )
+
+        return OmDetTurboObjectDetectionOutput(
+            loss=loss,
+            decoder_coord_logits=decoder_outputs.decoder_coords[-1],
+            decoder_class_logits=decoder_outputs.decoder_classes[-1],
+            init_reference_points=decoder_outputs.init_reference_points,
+            intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+            encoder_coord_logits=decoder_outputs.encoder_coord_logits,
+            encoder_class_logits=decoder_outputs.encoder_class_logits,
+            encoder_extracted_states=encoder_outputs.extracted_states,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+            classes_structure=classes_structure,
+        )
+
+
+__all__ = ["OmDetTurboForObjectDetection", "OmDetTurboPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py b/docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f5c51779b9d6b26be3aabd6fd1725dc992518a8
--- /dev/null
+++ b/docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py
@@ -0,0 +1,415 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for OmDet-Turbo. +""" + +import warnings +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_transforms import center_to_corners_format +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import requires + + +if TYPE_CHECKING: + from .modeling_omdet_turbo import OmDetTurboObjectDetectionOutput + + +class OmDetTurboTextKwargs(TextKwargs, total=False): + task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] + + +if is_torch_available(): + import torch + + +if is_torchvision_available(): + from torchvision.ops.boxes import batched_nms + + +class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: OmDetTurboTextKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 77, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + "task": None, + }, + "images_kwargs": {}, + } + + +class DictWithDeprecationWarning(dict): + message = ( + "The `classes` key is deprecated for `OmDetTurboProcessor.post_process_grounded_object_detection` " + "output dict and will be removed in a 4.51.0 version. Please use `text_labels` instead." + ) + + def __getitem__(self, key): + if key == "classes": + warnings.warn(self.message, FutureWarning) + return super().__getitem__("text_labels") + return super().__getitem__(key) + + def get(self, key, *args, **kwargs): + if key == "classes": + warnings.warn(self.message, FutureWarning) + return super().get("text_labels", *args, **kwargs) + return super().get(key, *args, **kwargs) + + +def clip_boxes(box, box_size: Tuple[int, int]): + """ + Clip the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + Args: + box (Tensor): The box to be clipped. + box_size (height, width): The clipping box's size. + """ + assert torch.isfinite(box).all(), "Box tensor contains infinite or NaN!" + height, width = box_size + x1 = box[:, 0].clamp(min=0, max=width) + y1 = box[:, 1].clamp(min=0, max=height) + x2 = box[:, 2].clamp(min=0, max=width) + y2 = box[:, 3].clamp(min=0, max=height) + box = torch.stack((x1, y1, x2, y2), dim=-1) + + return box + + +def compute_score(boxes): + """ + Compute logit scores per class for each box (proposal) and an array of class indices + corresponding to each proposal, flattened across the proposal_num. + The indices in `classes` will later be used to filter and match the predicted classes + with the input class names. 
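Concretely, for the function this docstring describes (toy shapes; note that the parameter named `boxes` actually holds per-class logits):

```python
import torch

logits = torch.randn(1, 3, 2)  # (batch, proposal_num=3, num_classes=2)
scores = torch.sigmoid(logits)
# One class id per (proposal, class) pair, aligned with scores.flatten(0, 1) for each image:
classes = torch.arange(2).unsqueeze(0).repeat(3, 1).flatten(0, 1)
print(scores.shape, classes)  # torch.Size([1, 3, 2]) tensor([0, 1, 0, 1, 0, 1])
```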
+ """ + num_classes = boxes.shape[2] + proposal_num = boxes.shape[1] + scores = torch.sigmoid(boxes) + classes = torch.arange(num_classes, device=boxes.device).unsqueeze(0).repeat(proposal_num, 1).flatten(0, 1) + return scores, classes + + +def _post_process_boxes_for_image( + boxes: "torch.Tensor", + scores: "torch.Tensor", + labels: "torch.Tensor", + image_num_classes: int, + image_size: Tuple[int, int], + threshold: float, + nms_threshold: float, + max_num_det: Optional[int] = None, +) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + """ + Filter predicted results using given thresholds and NMS. + + Args: + boxes (`torch.Tensor`): + A Tensor of predicted class-specific or class-agnostic boxes for the image. + Shape (num_queries, max_num_classes_in_batch * 4) if doing class-specific regression, + or (num_queries, 4) if doing class-agnostic regression. + scores (`torch.Tensor` of shape (num_queries, max_num_classes_in_batch + 1)): + A Tensor of predicted class scores for the image. + labels (`torch.Tensor` of shape (num_queries * (max_num_classes_in_batch + 1),)): + A Tensor of predicted labels for the image. + image_num_classes (`int`): + The number of classes queried for detection on the image. + image_size (`Tuple[int, int]`): + A tuple of (height, width) for the image. + threshold (`float`): + Only return detections with a confidence score exceeding this threshold. + nms_threshold (`float`): + The threshold to use for box non-maximum suppression. Value in [0, 1]. + max_num_det (`int`, *optional*): + The maximum number of detections to return. Default is None. + + Returns: + Tuple: A tuple with the following: + "boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format. + "scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection. 
+            "labels" (Tensor): A tensor of ids, where each id is the predicted class id for the corresponding detection.
+    """
+
+    # Filter by max number of detections
+    proposal_num = len(boxes) if max_num_det is None else max_num_det
+    scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False)
+    labels_per_image = labels[topk_indices]
+    boxes_per_image = boxes.view(-1, 1, 4).repeat(1, scores.shape[1], 1).view(-1, 4)
+    boxes_per_image = boxes_per_image[topk_indices]
+
+    # Convert and scale boxes to original image size
+    boxes_per_image = center_to_corners_format(boxes_per_image)
+    boxes_per_image = boxes_per_image * torch.tensor(image_size[::-1]).repeat(2).to(boxes_per_image.device)
+
+    # Filtering by confidence score
+    filter_mask = scores_per_image > threshold  # R x K
+    score_keep = filter_mask.nonzero(as_tuple=False).view(-1)
+    boxes_per_image = boxes_per_image[score_keep]
+    scores_per_image = scores_per_image[score_keep]
+    labels_per_image = labels_per_image[score_keep]
+
+    # Ensure we did not overflow to non-existing classes
+    filter_classes_mask = labels_per_image < image_num_classes
+    classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1)
+    boxes_per_image = boxes_per_image[classes_keep]
+    scores_per_image = scores_per_image[classes_keep]
+    labels_per_image = labels_per_image[classes_keep]
+
+    # NMS
+    keep = batched_nms(boxes_per_image, scores_per_image, labels_per_image, nms_threshold)
+    boxes_per_image = boxes_per_image[keep]
+    scores_per_image = scores_per_image[keep]
+    labels_per_image = labels_per_image[keep]
+
+    # Clip to image size
+    boxes_per_image = clip_boxes(boxes_per_image, image_size)
+
+    return boxes_per_image, scores_per_image, labels_per_image
+
+
+@requires(backends=("vision", "torchvision"))
+class OmDetTurboProcessor(ProcessorMixin):
+    r"""
+    Constructs an OmDet-Turbo processor which wraps a Deformable DETR image processor and an AutoTokenizer into a
+    single processor.
+
+    [`OmDetTurboProcessor`] offers all the functionalities of [`DetrImageProcessor`] and
+    [`AutoTokenizer`]. See the docstring of [`~OmDetTurboProcessor.__call__`] and [`~OmDetTurboProcessor.decode`]
+    for more information.
+
+    Args:
+        image_processor (`DetrImageProcessor`):
+            An instance of [`DetrImageProcessor`]. The image processor is a required input.
+        tokenizer (`AutoTokenizer`):
+            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = ("DetrImageProcessor", "DetrImageProcessorFast")
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, image_processor, tokenizer):
+        super().__init__(image_processor, tokenizer)
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[List[str], List[List[str]]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[OmDetTurboProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        This method uses the [`DetrImageProcessor.__call__`] method to prepare image(s) for the model, and
+        [`CLIPTokenizerFast.__call__`] to prepare text for the model.
+
+        Please refer to the docstring of the above two methods for more information.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
+            text (`Union[str, List[str], List[List[str]]]`):
+                The classes used to limit the scope of the open vocabulary detection. Expects a list of strings or a
+                list of lists of strings. Batched classes can be of different lengths.
+ Examples: ["cat", "dog", "bird"], [["cat", "dog", "bird"], ["hat", "person"], ["car"]] + Kwargs: + task (`Union[str, List[str], TextInput, PreTokenizedInput]`): + The grounded text used to guide open vocabulary detection. Expects a single string or a list of strings. + Examples: "Detect a cat, a dog, and a bird.",[ "Detect everything.", "Detect trees and flowers."] + When not provided, the default task is "Detect [class1], [class2], [class3]" etc. + ... + """ + if images is None or text is None: + raise ValueError("You have to specify both `images` and `text`") + + output_kwargs = self._merge_kwargs( + OmDetTurboProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = text.strip(" ").split(",") + + if not (len(text) and isinstance(text[0], (list, tuple))): + text = [text] + + task = output_kwargs["text_kwargs"].pop("task", None) + if task is None: + task = ["Detect {}.".format(", ".join(text_single)) for text_single in text] + elif not isinstance(task, (list, tuple)): + task = [task] + + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) + tasks_encoding = self.tokenizer(text=task, **output_kwargs["text_kwargs"]) + + classes = text + + classes_structure = torch.tensor([len(class_single) for class_single in classes], dtype=torch.long) + classes_flattened = [class_single for class_batch in classes for class_single in class_batch] + classes_encoding = self.tokenizer(text=classes_flattened, **output_kwargs["text_kwargs"]) + + encoding = BatchFeature() + encoding.update({f"tasks_{key}": value for key, value in tasks_encoding.items()}) + encoding.update({f"classes_{key}": value for key, value in classes_encoding.items()}) + encoding.update({"classes_structure": classes_structure}) + encoding.update(encoding_image_processor) + + return encoding + + # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
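A note on the class handling in `__call__` above: batched class lists of unequal lengths are flattened into one list for tokenization, and `classes_structure` records how many entries belong to each image so the model can regroup the embeddings later. With toy values:

```python
import torch

classes = [["cat", "dog", "bird"], ["hat", "person"]]
classes_structure = torch.tensor([len(c) for c in classes], dtype=torch.long)  # tensor([3, 2])
classes_flattened = [c for batch in classes for c in batch]  # ['cat', 'dog', 'bird', 'hat', 'person']
default_tasks = ["Detect {}.".format(", ".join(c)) for c in classes]
print(default_tasks)  # ['Detect cat, dog, bird.', 'Detect hat, person.']
```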
+ """ + return self.tokenizer.decode(*args, **kwargs) + + def _get_default_image_size(self) -> Tuple[int, int]: + height = ( + self.image_processor.size["height"] + if "height" in self.image_processor.size + else self.image_processor.size["shortest_edge"] + ) + width = ( + self.image_processor.size["width"] + if "width" in self.image_processor.size + else self.image_processor.size["longest_edge"] + ) + return height, width + + @deprecate_kwarg("score_threshold", new_name="threshold", version="4.51.0") + @deprecate_kwarg("classes", new_name="text_labels", version="4.51.0") + def post_process_grounded_object_detection( + self, + outputs: "OmDetTurboObjectDetectionOutput", + text_labels: Optional[Union[List[str], List[List[str]]]] = None, + threshold: float = 0.3, + nms_threshold: float = 0.5, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + max_num_det: Optional[int] = None, + ): + """ + Converts the raw output of [`OmDetTurboForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format and get the associated text class. + + Args: + outputs ([`OmDetTurboObjectDetectionOutput`]): + Raw outputs of the model. + text_labels (Union[List[str], List[List[str]]], *optional*): + The input classes names. If not provided, `text_labels` will be set to `None` in `outputs`. + threshold (float, defaults to 0.3): + Only return detections with a confidence score exceeding this threshold. + nms_threshold (float, defaults to 0.5): + The threshold to use for box non-maximum suppression. Value in [0, 1]. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + max_num_det (`int`, *optional*): + The maximum number of detections to return. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image + in the batch as predicted by the model. 
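A note on the geometry inside `_post_process_boxes_for_image` above: decoder boxes arrive as normalized `(center_x, center_y, width, height)` and are converted to absolute corner coordinates before thresholding and NMS. The arithmetic in isolation (a sketch; the actual code calls `center_to_corners_format` from `image_transforms`):

```python
import torch

def center_to_corners(boxes):
    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack((cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h), dim=-1)

pred = torch.tensor([[0.5, 0.5, 0.4, 0.2]])           # normalized (cx, cy, w, h)
height, width = 480, 640
scale = torch.tensor([width, height, width, height])  # image_size[::-1] repeated, as above
print(center_to_corners(pred) * scale)                # tensor([[192., 192., 448., 288.]])
```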
+ """ + + batch_size = len(outputs.decoder_coord_logits) + + # Inputs consistency check for target sizes + if target_sizes is None: + height, width = self._get_default_image_size() + target_sizes = [(height, width)] * batch_size + + if any(len(image_size) != 2 for image_size in target_sizes): + raise ValueError( + "Each element of target_sizes must contain the size (height, width) of each image of the batch" + ) + + if len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as output sequences") + + # Inputs consistency check for text labels + if text_labels is not None and isinstance(text_labels[0], str): + text_labels = [text_labels] + + if text_labels is not None and len(text_labels) != batch_size: + raise ValueError("Make sure that you pass in as many classes group as output sequences") + + # Convert target_sizes to list for easier handling + if isinstance(target_sizes, torch.Tensor): + target_sizes = target_sizes.tolist() + + batch_boxes = outputs.decoder_coord_logits + batch_logits = outputs.decoder_class_logits + batch_num_classes = outputs.classes_structure + + batch_scores, batch_labels = compute_score(batch_logits) + + results = [] + for boxes, scores, image_size, image_num_classes in zip( + batch_boxes, batch_scores, target_sizes, batch_num_classes + ): + boxes, scores, labels = _post_process_boxes_for_image( + boxes=boxes, + scores=scores, + labels=batch_labels, + image_num_classes=image_num_classes, + image_size=image_size, + threshold=threshold, + nms_threshold=nms_threshold, + max_num_det=max_num_det, + ) + result = DictWithDeprecationWarning( + {"boxes": boxes, "scores": scores, "labels": labels, "text_labels": None} + ) + results.append(result) + + # Add text labels + if text_labels is not None: + for result, image_text_labels in zip(results, text_labels): + result["text_labels"] = [image_text_labels[idx] for idx in result["labels"]] + + return results + + +__all__ = ["OmDetTurboProcessor"] diff --git a/docs/transformers/src/transformers/models/oneformer/__init__.py b/docs/transformers/src/transformers/models/oneformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..195b56e9fbd042e1170a5dc2fabdf5349af59ecf --- /dev/null +++ b/docs/transformers/src/transformers/models/oneformer/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
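The `__init__.py` that follows relies on the library's `_LazyModule` indirection: importing `transformers.models.oneformer` stays cheap because submodules are only imported when one of their attributes is first accessed. A greatly simplified sketch of the idea (this is not the actual `_LazyModule` implementation):

```python
import importlib
import types

class TinyLazyModule(types.ModuleType):
    """Toy illustration of lazy, attribute-driven imports."""

    def __init__(self, name, attr_to_submodule):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule  # e.g. {"OneFormerConfig": "configuration_oneformer"}

    def __getattr__(self, item):
        submodule = self._attr_to_submodule.get(item)
        if submodule is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {item!r}")
        module = importlib.import_module(f"{self.__name__}.{submodule}")  # real import happens here
        value = getattr(module, item)
        setattr(self, item, value)  # cache so later lookups skip __getattr__
        return value
```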
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_oneformer import * + from .image_processing_oneformer import * + from .modeling_oneformer import * + from .processing_oneformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py b/docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f577d0d4d62fa9a3e686f726292e1ea8007954b --- /dev/null +++ b/docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OneFormer model configuration""" + +from typing import Dict, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class OneFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OneFormerModel`]. It is used to instantiate a + OneFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OneFormer + [shi-labs/oneformer_ade20k_swin_tiny](https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny) architecture + trained on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. 
Cannot be specified if `backbone_config` is set. + ignore_value (`int`, *optional*, defaults to 255): + Values to be ignored in GT label while calculating loss. + num_queries (`int`, *optional*, defaults to 150): + Number of object queries. + no_object_weight (`float`, *optional*, defaults to 0.1): + Weight for no-object class predictions. + class_weight (`float`, *optional*, defaults to 2.0): + Weight for Classification CE loss. + mask_weight (`float`, *optional*, defaults to 5.0): + Weight for binary CE loss. + dice_weight (`float`, *optional*, defaults to 5.0): + Weight for dice loss. + contrastive_weight (`float`, *optional*, defaults to 0.5): + Weight for contrastive loss. + contrastive_temperature (`float`, *optional*, defaults to 0.07): + Initial value for scaling the contrastive logits. + train_num_points (`int`, *optional*, defaults to 12544): + Number of points to sample while calculating losses on mask predictions. + oversample_ratio (`float`, *optional*, defaults to 3.0): + Ratio to decide how many points to oversample. + importance_sample_ratio (`float`, *optional*, defaults to 0.75): + Ratio of points that are sampled via importance sampling. + init_std (`float`, *optional*, defaults to 0.02): + Standard deviation for normal intialization. + init_xavier_std (`float`, *optional*, defaults to 1.0): + Standard deviation for xavier uniform initialization. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + Epsilon for layer normalization. + is_training (`bool`, *optional*, defaults to `False`): + Whether to run in training or inference mode. + use_auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether to calculate loss using intermediate predictions from transformer decoder. + output_auxiliary_logits (`bool`, *optional*, defaults to `True`): + Whether to return intermediate predictions from transformer decoder. + strides (`list`, *optional*, defaults to `[4, 8, 16, 32]`): + List containing the strides for feature maps in the encoder. + task_seq_len (`int`, *optional*, defaults to 77): + Sequence length for tokenizing text list input. + text_encoder_width (`int`, *optional*, defaults to 256): + Hidden size for text encoder. + text_encoder_context_length (`int`, *optional*, defaults to 77): + Input sequence length for text encoder. + text_encoder_num_layers (`int`, *optional*, defaults to 6): + Number of layers for transformer in text encoder. + text_encoder_vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size for tokenizer. + text_encoder_proj_layers (`int`, *optional*, defaults to 2): + Number of layers in MLP for project text queries. + text_encoder_n_ctx (`int`, *optional*, defaults to 16): + Number of learnable text context queries. + conv_dim (`int`, *optional*, defaults to 256): + Feature map dimension to map outputs from the backbone. + mask_dim (`int`, *optional*, defaults to 256): + Dimension for feature maps in pixel decoder. + hidden_dim (`int`, *optional*, defaults to 256): + Dimension for hidden states in transformer decoder. + encoder_feedforward_dim (`int`, *optional*, defaults to 1024): + Dimension for FFN layer in pixel decoder. + norm (`str`, *optional*, defaults to `"GN"`): + Type of normalization. + encoder_layers (`int`, *optional*, defaults to 6): + Number of layers in pixel decoder. + decoder_layers (`int`, *optional*, defaults to 10): + Number of layers in transformer decoder. + use_task_norm (`bool`, *optional*, defaults to `True`): + Whether to normalize the task token. 
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the transformer layers of the pixel and transformer decoders.
+ dropout (`float`, *optional*, defaults to 0.1):
+ Dropout probability for the pixel and transformer decoders.
+ dim_feedforward (`int`, *optional*, defaults to 2048):
+ Dimension for the FFN layer in the transformer decoder.
+ pre_norm (`bool`, *optional*, defaults to `False`):
+ Whether to normalize hidden states before the attention layers in the transformer decoder.
+ enforce_input_proj (`bool`, *optional*, defaults to `False`):
+ Whether to project the hidden states in the transformer decoder.
+ query_dec_layers (`int`, *optional*, defaults to 2):
+ Number of layers in the query transformer.
+ common_stride (`int`, *optional*, defaults to 4):
+ Common stride used for features in the pixel decoder.
+
+ Examples:
+ ```python
+ >>> from transformers import OneFormerConfig, OneFormerModel
+
+ >>> # Initializing a OneFormer shi-labs/oneformer_ade20k_swin_tiny configuration
+ >>> configuration = OneFormerConfig()
+ >>> # Initializing a model (with random weights) from the shi-labs/oneformer_ade20k_swin_tiny style configuration
+ >>> model = OneFormerModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "oneformer"
+ attribute_map = {"hidden_size": "hidden_dim"}
+
+ def __init__(
+ self,
+ backbone_config: Optional[Dict] = None,
+ backbone: Optional[str] = None,
+ use_pretrained_backbone: bool = False,
+ use_timm_backbone: bool = False,
+ backbone_kwargs: Optional[Dict] = None,
+ ignore_value: int = 255,
+ num_queries: int = 150,
+ no_object_weight: float = 0.1,
+ class_weight: float = 2.0,
+ mask_weight: float = 5.0,
+ dice_weight: float = 5.0,
+ contrastive_weight: float = 0.5,
+ contrastive_temperature: float = 0.07,
+ train_num_points: int = 12544,
+ oversample_ratio: float = 3.0,
+ importance_sample_ratio: float = 0.75,
+ init_std: float = 0.02,
+ init_xavier_std: float = 1.0,
+ layer_norm_eps: float = 1e-05,
+ is_training: bool = False,
+ use_auxiliary_loss: bool = True,
+ output_auxiliary_logits: bool = True,
+ strides: Optional[list] = [4, 8, 16, 32],
+ task_seq_len: int = 77,
+ text_encoder_width: int = 256,
+ text_encoder_context_length: int = 77,
+ text_encoder_num_layers: int = 6,
+ text_encoder_vocab_size: int = 49408,
+ text_encoder_proj_layers: int = 2,
+ text_encoder_n_ctx: int = 16,
+ conv_dim: int = 256,
+ mask_dim: int = 256,
+ hidden_dim: int = 256,
+ encoder_feedforward_dim: int = 1024,
+ norm: str = "GN",
+ encoder_layers: int = 6,
+ decoder_layers: int = 10,
+ use_task_norm: bool = True,
+ num_attention_heads: int = 8,
+ dropout: float = 0.1,
+ dim_feedforward: int = 2048,
+ pre_norm: bool = False,
+ enforce_input_proj: bool = False,
+ query_dec_layers: int = 2,
+ common_stride: int = 4,
+ **kwargs,
+ ):
+ if backbone_config is None and backbone is None:
+ logger.info("`backbone_config` is unset. 
Initializing the config with the default `Swin` backbone.") + backbone_config = CONFIG_MAPPING["swin"]( + image_size=224, + num_channels=3, + patch_size=4, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + drop_path_rate=0.3, + use_absolute_embeddings=False, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.ignore_value = ignore_value + self.num_queries = num_queries + self.no_object_weight = no_object_weight + self.class_weight = class_weight + self.mask_weight = mask_weight + self.dice_weight = dice_weight + self.contrastive_weight = contrastive_weight + self.contrastive_temperature = contrastive_temperature + self.train_num_points = train_num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.layer_norm_eps = layer_norm_eps + self.is_training = is_training + self.use_auxiliary_loss = use_auxiliary_loss + self.output_auxiliary_logits = output_auxiliary_logits + self.strides = strides + self.task_seq_len = task_seq_len + self.text_encoder_width = text_encoder_width + self.text_encoder_context_length = text_encoder_context_length + self.text_encoder_num_layers = text_encoder_num_layers + self.text_encoder_vocab_size = text_encoder_vocab_size + self.text_encoder_proj_layers = text_encoder_proj_layers + self.text_encoder_n_ctx = text_encoder_n_ctx + self.conv_dim = conv_dim + self.mask_dim = mask_dim + self.hidden_dim = hidden_dim + self.encoder_feedforward_dim = encoder_feedforward_dim + self.norm = norm + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.use_task_norm = use_task_norm + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.dim_feedforward = dim_feedforward + self.pre_norm = pre_norm + self.enforce_input_proj = enforce_input_proj + self.query_dec_layers = query_dec_layers + self.common_stride = common_stride + self.num_hidden_layers = decoder_layers + + super().__init__(**kwargs) + + +__all__ = ["OneFormerConfig"] diff --git a/docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py b/docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py new file mode 100644 index 0000000000000000000000000000000000000000..960634f9f4e89696e020052cc7ef84a48e54ea1a --- /dev/null +++ b/docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py @@ -0,0 +1,1191 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert OneFormer checkpoints from the original repository. URL: https://github.com/SHI-Labs/OneFormer"""
+
+import os
+import sys
+from argparse import ArgumentParser
+from dataclasses import dataclass
+from pathlib import Path
+from pprint import pformat
+from typing import Any, Dict, Iterator, List, Set, Tuple
+
+import requests
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torch import Tensor, nn
+
+
+try:
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.config import get_cfg
+ from detectron2.data import MetadataCatalog
+ from detectron2.projects.deeplab import add_deeplab_config
+except ImportError:
+ pass
+from transformers import CLIPTokenizer, DinatConfig, SwinConfig
+from transformers.models.oneformer.image_processing_oneformer import OneFormerImageProcessor
+from transformers.models.oneformer.modeling_oneformer import (
+ OneFormerConfig,
+ OneFormerForUniversalSegmentation,
+ OneFormerForUniversalSegmentationOutput,
+ OneFormerModel,
+ OneFormerModelOutput,
+)
+from transformers.models.oneformer.processing_oneformer import OneFormerProcessor
+from transformers.utils import logging
+
+
+StateDict = Dict[str, Tensor]
+
+logging.set_verbosity_info()
+logger = logging.get_logger()
+
+torch.manual_seed(0)
+
+
+class TrackedStateDict:
+ def __init__(self, to_track: Dict):
+ """This class "tracks" a Python dictionary by keeping track of which items are updated.
+
+ Args:
+ to_track (Dict): The dictionary we wish to track
+ """
+ self.to_track = to_track
+ self._seen: Set[str] = set()
+
+ def __getitem__(self, key: str) -> Any:
+ return self.to_track[key]
+
+ def __setitem__(self, key: str, item: Any):
+ self._seen.add(key)
+ self.to_track[key] = item
+
+ def diff(self) -> List[str]:
+ """This method returns the set difference between the keys in the tracked state dict and the ones we have updated so far.
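+ Note that only writes through `__setitem__` are recorded as seen; plain reads via `__getitem__` are not tracked.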
+ This is an effective way to check whether we have updated all the keys.
+
+ Returns:
+ List[str]: List of keys not yet updated
+ """
+ return set(self.to_track.keys()) - self._seen
+
+ def copy(self) -> Dict:
+ # proxy the call to the internal dictionary
+ return self.to_track.copy()
+
+
+# Image to verify the result
+def prepare_img():
+ url = "https://praeclarumjj3.github.io/files/coco.jpeg"
+ img_data = requests.get(url, stream=True).raw
+ im = Image.open(img_data)
+ return im
+
+
+@dataclass
+class Args:
+ """Fake command-line arguments needed by the oneformer/detectron2 implementation"""
+
+ config_file: str
+
+
+def setup_cfg(args: Args):
+ # load config from file and command-line arguments
+ cfg = get_cfg()
+ add_deeplab_config(cfg)
+ add_common_config(cfg)
+ add_oneformer_config(cfg)
+ add_swin_config(cfg)
+ add_dinat_config(cfg)
+ cfg.merge_from_file(args.config_file)
+ cfg.freeze()
+ return cfg
+
+
+class OriginalOneFormerConfigToOursConverter:
+ def __call__(self, original_config: object, is_swin: bool) -> OneFormerConfig:
+ model = original_config.MODEL
+
+ dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
+ id2label = dict(enumerate(dataset_catalog.stuff_classes))
+ label2id = {label: idx for idx, label in id2label.items()}
+
+ if is_swin:
+ if model.SWIN.EMBED_DIM == 96:
+ backbone_config = SwinConfig.from_pretrained(
+ "microsoft/swin-tiny-patch4-window7-224",
+ drop_path_rate=model.SWIN.DROP_PATH_RATE,
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ )
+ elif model.SWIN.EMBED_DIM == 192:
+ backbone_config = SwinConfig.from_pretrained(
+ "microsoft/swin-large-patch4-window12-384",
+ drop_path_rate=model.SWIN.DROP_PATH_RATE,
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ )
+ else:
+ raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!")
+ else:
+ backbone_config = DinatConfig.from_pretrained(
+ "shi-labs/dinat-large-11x11-in22k-in1k-384",
+ dilations=model.DiNAT.DILATIONS,
+ kernel_size=model.DiNAT.KERNEL_SIZE,
+ out_features=["stage1", "stage2", "stage3", "stage4"],
+ )
+
+ config: OneFormerConfig = OneFormerConfig(
+ backbone_config=backbone_config,
+ output_attentions=True,
+ output_hidden_states=True,
+ return_dict=True,
+ ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE,
+ num_classes=model.SEM_SEG_HEAD.NUM_CLASSES,
+ num_queries=model.ONE_FORMER.NUM_OBJECT_QUERIES,
+ no_object_weight=model.ONE_FORMER.NO_OBJECT_WEIGHT,
+ class_weight=model.ONE_FORMER.CLASS_WEIGHT,
+ mask_weight=model.ONE_FORMER.MASK_WEIGHT,
+ dice_weight=model.ONE_FORMER.DICE_WEIGHT,
+ contrastive_weight=model.ONE_FORMER.CONTRASTIVE_WEIGHT,
+ contrastive_temperature=model.ONE_FORMER.CONTRASTIVE_TEMPERATURE,
+ train_num_points=model.ONE_FORMER.TRAIN_NUM_POINTS,
+ oversample_ratio=model.ONE_FORMER.OVERSAMPLE_RATIO,
+ importance_sample_ratio=model.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO,
+ init_std=0.02,
+ init_xavier_std=1.0,
+ layer_norm_eps=1e-05,
+ is_training=False,
+ use_auxiliary_loss=model.ONE_FORMER.DEEP_SUPERVISION,
+ output_auxiliary_logits=True,
+ strides=[4, 8, 16, 32],
+ task_seq_len=original_config.INPUT.TASK_SEQ_LEN,
+ max_seq_len=original_config.INPUT.MAX_SEQ_LEN,
+ text_encoder_width=model.TEXT_ENCODER.WIDTH,
+ text_encoder_context_length=model.TEXT_ENCODER.CONTEXT_LENGTH,
+ text_encoder_num_layers=model.TEXT_ENCODER.NUM_LAYERS,
+ text_encoder_vocab_size=model.TEXT_ENCODER.VOCAB_SIZE,
+ text_encoder_proj_layers=model.TEXT_ENCODER.PROJ_NUM_LAYERS,
+ text_encoder_n_ctx=model.TEXT_ENCODER.N_CTX,
+ 
conv_dim=model.SEM_SEG_HEAD.CONVS_DIM, + mask_dim=model.SEM_SEG_HEAD.MASK_DIM, + hidden_dim=model.ONE_FORMER.HIDDEN_DIM, + norm=model.SEM_SEG_HEAD.NORM, + encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS, + encoder_feedforward_dim=1024, + decoder_layers=model.ONE_FORMER.DEC_LAYERS, + use_task_norm=model.ONE_FORMER.USE_TASK_NORM, + num_attention_heads=model.ONE_FORMER.NHEADS, + dropout=model.ONE_FORMER.DROPOUT, + dim_feedforward=model.ONE_FORMER.DIM_FEEDFORWARD, + pre_norm=model.ONE_FORMER.PRE_NORM, + enforce_input_proj=model.ONE_FORMER.ENFORCE_INPUT_PROJ, + query_dec_layers=model.ONE_FORMER.CLASS_DEC_LAYERS, + common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE, + id2label=id2label, + label2id=label2id, + ) + + return config + + +class OriginalOneFormerConfigToProcessorConverter: + def __call__(self, original_config: object, model_repo: str) -> OneFormerProcessor: + model = original_config.MODEL + model_input = original_config.INPUT + dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0]) + + if "ade20k" in model_repo: + class_info_file = "ade20k_panoptic.json" + elif "coco" in model_repo: + class_info_file = "coco_panoptic.json" + elif "cityscapes" in model_repo: + class_info_file = "cityscapes_panoptic.json" + else: + raise ValueError("Invalid Dataset!") + + image_processor = OneFormerImageProcessor( + image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(), + image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(), + size=model_input.MIN_SIZE_TEST, + max_size=model_input.MAX_SIZE_TEST, + num_labels=model.SEM_SEG_HEAD.NUM_CLASSES, + ignore_index=dataset_catalog.ignore_label, + class_info_file=class_info_file, + ) + + tokenizer = CLIPTokenizer.from_pretrained(model_repo) + + return OneFormerProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + task_seq_length=original_config.INPUT.TASK_SEQ_LEN, + max_seq_length=original_config.INPUT.MAX_SEQ_LEN, + ) + + +class OriginalOneFormerCheckpointToOursConverter: + def __init__(self, original_model: nn.Module, config: OneFormerConfig): + self.original_model = original_model + self.config = config + + def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict): + for src_key, dst_key in renamed_keys: + dst_state_dict[dst_key] = src_state_dict.pop(src_key) + + # Swin Backbone + def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + renamed_keys = [ + ( + f"{src_prefix}.patch_embed.proj.weight", + f"{dst_prefix}.embeddings.patch_embeddings.projection.weight", + ), + (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"), + (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"), + (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"), + ] + num_layers = len(config.backbone_config.depths) + for layer_idx in range(num_layers): + for block_idx in range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias", + ), + ( + 
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias", + ), + ] + ) + + # second norm + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias", + ), + ] + ) + + # mlp + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias", + ), + ] + ) + + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index", + f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index", + ) + ] + ) + + if layer_idx < num_layers - 1: + # patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight", + 
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.bias", + ), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Dinat Backbone + def replace_dinat_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + renamed_keys = rename_keys_for_weight_bias(f"{src_prefix}.patch_embed.norm", f"{dst_prefix}.embeddings.norm") + + for i in range(2): + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.patch_embed.proj.{i}", + f"{dst_prefix}.embeddings.patch_embeddings.projection.{i}", + ) + ) + + num_layers = len(config.backbone_config.depths) + for layer_idx in range(num_layers): + for block_idx in range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm1", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_before", + ) + ) + + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm2", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_after", + ) + ) + + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.rpb", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.rpb", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + 
src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.proj", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.output.dense", + ) + ) + + # mlp + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc1", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.intermediate.dense", + ) + ) + + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc2", + f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.output.dense", + ) + ) + + if layer_idx < num_layers - 1: + # patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.levels.{layer_idx}.downsample.reduction.weight", + f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.levels.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.levels.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.bias", + ), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Backbone + Pixel Decoder + def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict, is_swin: bool): + dst_prefix: str = "pixel_level_module.decoder" + src_prefix: str = "sem_seg_head.pixel_decoder" + + if is_swin: + self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config) + else: + self.replace_dinat_backbone(dst_state_dict, src_state_dict, self.config) + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str): + self_attn_keys = [] + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights") + ) + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj") + ) + self_attn_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets") + ) + self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj")) + + return self_attn_keys + + def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str): + encoder_keys = [] + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1")) + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2")) + encoder_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm") + ) + encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm")) + encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")) + + return encoder_keys + + # convolution layer for final features + renamed_keys = [ + (f"{src_prefix}.adapter_1.weight", 
f"{dst_prefix}.adapter_1.0.weight"), + (f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"), + (f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"), + ] + + renamed_keys.extend( + [ + (f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"), + (f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"), + (f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"), + ] + ) + + # proj layers + for i in range(3): + for j in range(2): + renamed_keys.extend( + [ + (f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"), + (f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"), + ] + ) + + renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")]) + + # layers + for layer_idx in range(self.config.encoder_layers): + renamed_keys.extend( + rename_keys_for_encoder_layer( + f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}" + ) + ) + + # proj + renamed_keys.extend( + [ + (f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"), + (f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + # Transformer Decoder + def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder.layers" + src_prefix: str = "sem_seg_head.predictor" + for i in range(self.config.decoder_layers - 1): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = src_state_dict.pop( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight" + ) + in_proj_bias = src_state_dict.pop( + f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias" + ) + # next, add query, keys and values (in that order) to the state dict + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.bias"] = in_proj_bias[:256] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.bias"] = in_proj_bias[256:512] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module" + src_prefix: str = "sem_seg_head.predictor" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_attn(src_prefix: str, dst_prefix: str): + attn_keys = [ + (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"), + (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"), + ] + attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj")) + + return attn_keys + + def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str): + attn_keys = [] + attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj")) + + return attn_keys + + def rename_keys_for_query_transformer_layer(src_prefix: str, 
dst_prefix: str): + query_transformer_layer_keys = [] + + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.norm1") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.norm2") + ) + query_transformer_layer_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.norm3", f"{dst_prefix}.norm3") + ) + + query_transformer_layer_keys.extend( + rename_keys_for_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn") + ) + + query_transformer_layer_keys.extend( + rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn") + ) + + return query_transformer_layer_keys + + def rename_keys_for_cross_attn_layer(src_prefix: str, dst_prefix: str): + cross_attn_layer_keys = [] + + cross_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm")) + cross_attn_layer_keys.extend( + rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn") + ) + + return cross_attn_layer_keys + + def rename_keys_for_self_attn_layer(src_prefix: str, dst_prefix: str): + self_attn_layer_keys = [] + + self_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm")) + self_attn_layer_keys.extend( + rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn") + ) + + return self_attn_layer_keys + + def rename_keys_for_ffn_layer(src_prefix: str, dst_prefix: str): + ffn_layer_keys = [] + + ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1")) + ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2")) + ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm")) + + return ffn_layer_keys + + def rename_keys_for_transformer_decoder_layer(src_prefix: str, dst_prefix: str, idx: int): + transformer_decoder_layer_keys = [] + + transformer_decoder_layer_keys.extend( + rename_keys_for_cross_attn_layer( + f"{src_prefix}.transformer_cross_attention_layers.{idx}", f"{dst_prefix}.{idx}.cross_attn" + ) + ) + + transformer_decoder_layer_keys.extend( + rename_keys_for_self_attn_layer( + f"{src_prefix}.transformer_self_attention_layers.{idx}", f"{dst_prefix}.{idx}.self_attn" + ) + ) + + transformer_decoder_layer_keys.extend( + rename_keys_for_ffn_layer(f"{src_prefix}.transformer_ffn_layers.{idx}", f"{dst_prefix}.{idx}.ffn") + ) + + return transformer_decoder_layer_keys + + # positional embedding for object queries + renamed_keys = [ + (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"), + (f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"), + ] + + # norm + renamed_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.decoder_norm", f"{dst_prefix}.decoder.decoder_norm") + ) + + # proj + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.class_input_proj", f"{dst_prefix}.decoder.query_input_projection" + ) + ) + + renamed_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.class_embed", f"{dst_prefix}.decoder.class_embed") + ) + + for i in range(3): + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.mask_embed.layers.{i}", 
f"{dst_prefix}.decoder.mask_embed.layers.{i}.0" + ) + ) + + # norm + renamed_keys.extend( + rename_keys_for_weight_bias( + f"{src_prefix}.class_transformer.decoder.norm", f"{dst_prefix}.decoder.query_transformer.decoder.norm" + ) + ) + + # transformer to update queries with task tokens + for i in range(self.config.query_dec_layers): + renamed_keys.extend( + rename_keys_for_query_transformer_layer( + f"{src_prefix}.class_transformer.decoder.layers.{i}", + f"{dst_prefix}.decoder.query_transformer.decoder.layers.{i}", + ) + ) + + # decoder layers + for i in range(self.config.decoder_layers - 1): + renamed_keys.extend( + rename_keys_for_transformer_decoder_layer( + f"{src_prefix}", + f"{dst_prefix}.decoder.layers", + i, + ) + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict) + + def replace_task_mlp(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "task_encoder" + src_prefix: str = "task_mlp" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + renamed_keys = [] + + for i in range(2): + renamed_keys.extend( + rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.task_mlp.layers.{i}.0") + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_text_projector(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "text_mapper.text_projector" + src_prefix: str = "text_projector" + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + renamed_keys = [] + + for i in range(self.config.text_encoder_config["text_encoder_proj_layers"]): + renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.{i}.0")) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_text_mapper(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "text_mapper.text_encoder" + src_prefix: str = "text_encoder" + + self.replace_text_projector(dst_state_dict, src_state_dict) + + def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str): + return [ + (f"{src_prefix}.weight", f"{dst_prefix}.weight"), + (f"{src_prefix}.bias", f"{dst_prefix}.bias"), + ] + + def rename_keys_for_attn(src_prefix: str, dst_prefix: str): + attn_keys = [ + (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"), + (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"), + ] + attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj")) + + return attn_keys + + def rename_keys_for_layer(src_prefix: str, dst_prefix: str): + resblock_keys = [] + + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_fc", f"{dst_prefix}.mlp.fc1")) + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_proj", f"{dst_prefix}.mlp.fc2")) + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_1", f"{dst_prefix}.layer_norm1")) + resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_2", f"{dst_prefix}.layer_norm2")) + resblock_keys.extend(rename_keys_for_attn(f"{src_prefix}.attn", f"{dst_prefix}.self_attn")) + + return resblock_keys + + renamed_keys = [ + ("prompt_ctx.weight", "text_mapper.prompt_ctx.weight"), + ] + + 
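+ # The original text encoder uses CLIP-style parameter names (token_embedding, ln_final,
+ # transformer.resblocks.{i} with ln_1/ln_2 and mlp.c_fc/c_proj); the renames below only move
+ # those tensors onto the HF "text_mapper.text_encoder" layout, their shapes are unchanged.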
renamed_keys.extend( + [ + (f"{src_prefix}.positional_embedding", f"{dst_prefix}.positional_embedding"), + (f"{src_prefix}.token_embedding.weight", f"{dst_prefix}.token_embedding.weight"), + ] + ) + + renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_final", f"{dst_prefix}.ln_final")) + + for i in range(self.config.text_encoder_config["text_encoder_num_layers"]): + renamed_keys.extend( + rename_keys_for_layer( + f"{src_prefix}.transformer.resblocks.{i}", f"{dst_prefix}.transformer.layers.{i}" + ) + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def convert(self, oneformer: OneFormerModel, is_swin: bool) -> OneFormerModel: + dst_state_dict = TrackedStateDict(oneformer.state_dict()) + src_state_dict = self.original_model.state_dict() + + self.replace_pixel_module(dst_state_dict, src_state_dict, is_swin) + self.replace_transformer_module(dst_state_dict, src_state_dict) + self.replace_task_mlp(dst_state_dict, src_state_dict) + if self.config.is_training: + self.replace_text_mapper(dst_state_dict, src_state_dict) + + logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}") + logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}") + logger.info("🙌 Done") + + oneformer.load_state_dict(dst_state_dict) + + return oneformer + + @staticmethod + def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object, Path, Path]]: + checkpoints: List[Path] = checkpoints_dir.glob("**/*.pth") + + for checkpoint in checkpoints: + logger.info(f"💪 Converting {checkpoint.stem}") + # find associated config file + config: Path = config_dir / f"{checkpoint.stem}.yaml" + + yield config, checkpoint + + +def post_process_sem_seg_output(outputs: OneFormerForUniversalSegmentationOutput, target_size: Tuple[int, int]): + # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1] + class_queries_logits = outputs.class_queries_logits + # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_queries_logits = outputs.masks_queries_logits + if target_size is not None: + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, + size=target_size, + mode="bilinear", + align_corners=False, + ) + # remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_probs = masks_queries_logits.sigmoid() + # now we want to sum over the queries, + # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $ + # where $ softmax(p) \in R^{q, c} $ is the mask classes + # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities + # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + return segmentation + + +def test( + original_model, + our_model: OneFormerForUniversalSegmentation, + processor: OneFormerProcessor, + model_repo: str, +): + def _preprocess_text(text_list=None, max_length=77): + if text_list is None: + raise ValueError("tokens cannot be None.") + + tokens = tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True) + + attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"] + + token_inputs = [] + for attn_mask, input_id in zip(attention_masks, input_ids): + token = torch.tensor(attn_mask) * torch.tensor(input_id) + token_inputs.append(token.unsqueeze(0)) + + token_inputs = torch.cat(token_inputs, dim=0) + return token_inputs + + with torch.no_grad(): + tokenizer = 
CLIPTokenizer.from_pretrained(model_repo) + original_model = original_model.eval() + our_model = our_model.eval() + + im = prepare_img() + + tr = T.Compose( + [ + T.Resize((640, 640)), + T.ToTensor(), + T.Normalize( + mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0, + std=torch.tensor([58.395, 57.120, 57.375]) / 255.0, + ), + ], + ) + + x = tr(im).unsqueeze(0) + + task_input = ["the task is semantic"] + task_token = _preprocess_text(task_input, max_length=processor.task_seq_length) + + original_model_backbone_features = original_model.backbone(x.clone()) + + our_model_output: OneFormerModelOutput = our_model.model(x.clone(), task_token, output_hidden_states=True) + + for original_model_feature, our_model_feature in zip( + original_model_backbone_features.values(), our_model_output.encoder_hidden_states + ): + assert torch.allclose(original_model_feature, our_model_feature, atol=3e-3), ( + "The backbone features are not the same." + ) + mask_features, _, multi_scale_features, _, _ = original_model.sem_seg_head.pixel_decoder.forward_features( + original_model_backbone_features + ) + + original_pixel_decoder_features = [] + original_pixel_decoder_features.append(mask_features) + for i in range(len(multi_scale_features)): + original_pixel_decoder_features.append(multi_scale_features[i]) + + for original_model_feature, our_model_feature in zip( + original_pixel_decoder_features, our_model_output.pixel_decoder_hidden_states + ): + assert torch.allclose(original_model_feature, our_model_feature, atol=3e-4), ( + "The pixel decoder feature are not the same" + ) + + tr_complete = T.Compose( + [ + T.Resize((640, 640)), + T.ToTensor(), + ], + ) + + y = (tr_complete(im) * 255.0).to(torch.int).float() + + # let's test the full model + original_model_out = original_model([{"image": y.clone(), "task": "The task is semantic"}]) + + original_segmentation = original_model_out[0]["sem_seg"] + + our_model_out: OneFormerForUniversalSegmentationOutput = our_model( + x.clone(), task_token, output_hidden_states=True + ) + + our_segmentation = post_process_sem_seg_output(our_model_out, target_size=(640, 640))[0] + + assert torch.allclose(original_segmentation, our_segmentation, atol=1e-3), ( + "The segmentation image is not the same." + ) + + logger.info("✅ Test passed!") + + +def get_name(checkpoint_file: Path): + model_name_raw: str = checkpoint_file.stem + + backbone = "swin" if "swin" in model_name_raw else "dinat" + dataset = "" + if "coco" in model_name_raw: + dataset = "coco" + elif "ade20k" in model_name_raw: + dataset = "ade20k" + elif "cityscapes" in model_name_raw: + dataset = "cityscapes" + else: + raise ValueError( + f"{model_name_raw} must be wrong since we didn't find 'coco' or 'ade20k' or 'cityscapes' in it " + ) + + backbone_types = ["tiny", "large"] + + backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0] + + model_name = f"oneformer_{dataset}_{backbone}_{backbone_type}" + + return model_name + + +if __name__ == "__main__": + parser = ArgumentParser( + description=( + "Command line to convert the original oneformer models (with swin backbone) to transformers" + " implementation." + ) + ) + + parser.add_argument( + "--checkpoints_dir", + type=Path, + help=( + "A directory containing the model's checkpoints. 
The directory has to have the following"
+ " structure: //.pth; where name must follow the"
+ " following nomenclature: oneformer___"
+ ),
+ )
+ parser.add_argument(
+ "--configs_dir",
+ type=Path,
+ help=(
+ "A directory containing the model's configs, see the detectron2 docs. The directory has to have the following"
+ " structure: //.yaml; where name must follow the"
+ " following nomenclature: oneformer___"
+ ),
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ required=True,
+ type=Path,
+ help="Path to the folder to output PyTorch models.",
+ )
+ parser.add_argument(
+ "--oneformer_dir",
+ required=True,
+ type=Path,
+ help=(
+ "A path to OneFormer's original implementation directory. You can download from here: "
+ "https://github.com/SHI-Labs/OneFormer"
+ ),
+ )
+
+ args = parser.parse_args()
+
+ checkpoints_dir: Path = args.checkpoints_dir
+ config_dir: Path = args.configs_dir
+ save_directory: Path = args.pytorch_dump_folder_path
+ oneformer_dir: Path = args.oneformer_dir
+ # append the parent of the OneFormer directory to the path
+ sys.path.append(str(oneformer_dir.parent))
+ # and import what's needed
+ from OneFormer.oneformer import add_common_config, add_dinat_config, add_oneformer_config, add_swin_config
+ from OneFormer.oneformer.oneformer_model import OneFormer as OriginalOneFormer
+
+ if not save_directory.exists():
+ save_directory.mkdir(parents=True)
+
+ for config_file, checkpoint_file in OriginalOneFormerCheckpointToOursConverter.using_dirs(
+ checkpoints_dir, config_dir
+ ):
+ processor = OriginalOneFormerConfigToProcessorConverter()(
+ setup_cfg(Args(config_file=config_file)), os.path.join("shi-labs", config_file.stem)
+ )
+
+ original_config = setup_cfg(Args(config_file=config_file))
+ oneformer_kwargs = OriginalOneFormer.from_config(original_config)
+
+ original_model = OriginalOneFormer(**oneformer_kwargs).eval()
+
+ DetectionCheckpointer(original_model).load(str(checkpoint_file))
+
+ is_swin = "swin" in config_file.stem
+
+ config: OneFormerConfig = OriginalOneFormerConfigToOursConverter()(original_config, is_swin)
+
+ oneformer = OneFormerModel(config=config).eval()
+
+ converter = OriginalOneFormerCheckpointToOursConverter(original_model, config)
+
+ oneformer = converter.convert(oneformer, is_swin)
+
+ oneformer_for_universal_segmentation = OneFormerForUniversalSegmentation(config=config).eval()
+
+ oneformer_for_universal_segmentation.model = oneformer
+
+ test(
+ original_model,
+ oneformer_for_universal_segmentation,
+ processor,
+ os.path.join("shi-labs", config_file.stem),
+ )
+
+ model_name = get_name(checkpoint_file)
+ logger.info(f"🪄 Saving {model_name}")
+
+ processor.save_pretrained(save_directory / model_name)
+ oneformer_for_universal_segmentation.save_pretrained(save_directory / model_name)
+
+ processor.push_to_hub(
+ repo_id=os.path.join("shi-labs", config_file.stem),
+ commit_message="Add configs",
+ use_temp_dir=True,
+ )
+ oneformer_for_universal_segmentation.push_to_hub(
+ repo_id=os.path.join("shi-labs", config_file.stem),
+ commit_message="Add model",
+ use_temp_dir=True,
+ ) diff --git a/docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py b/docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py new file mode 100644 index 0000000000000000000000000000000000000000..956bd3e7e2fa7af482e79ec8d785b853757a1f1e --- /dev/null +++ b/docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py @@ -0,0 +1,1356 @@ +# coding=utf-8 +# 
Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for OneFormer.""" + +import json +import os +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import numpy as np +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import RepositoryNotFoundError + +from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + PaddingMode, + get_resize_output_image_size, + pad, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + is_torch_tensor, + logging, +) +from ...utils.deprecation import deprecate_kwarg + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + from torch import nn + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width( + images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. 
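+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.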
+ """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle +def binary_mask_to_rle(mask): + """ + Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + mask (`torch.Tensor` or `numpy.array`): + A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target + segment_id or class_id. + Returns: + `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE + format. + """ + if is_torch_tensor(mask): + mask = mask.numpy() + + pixels = mask.flatten() + pixels = np.concatenate([[0], pixels, [0]]) + runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 + runs[1::2] -= runs[::2] + return list(runs) + + +# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle +def convert_segmentation_to_rle(segmentation): + """ + Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + segmentation (`torch.Tensor` or `numpy.array`): + A segmentation map of shape `(height, width)` where each value denotes a segment or class id. + Returns: + `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. + """ + segment_ids = torch.unique(segmentation) + + run_length_encodings = [] + for idx in segment_ids: + mask = torch.where(segmentation == idx, 1, 0) + rle = binary_mask_to_rle(mask) + run_length_encodings.append(rle) + + return run_length_encodings + + +# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. 
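+
+ Example (an illustrative sketch; the toy tensors and `num_labels=150` below are made-up values):
+
+ ```python
+ >>> import torch
+ >>> from transformers.models.oneformer.image_processing_oneformer import remove_low_and_no_objects
+
+ >>> masks = torch.rand(3, 4, 4)
+ >>> scores = torch.tensor([0.90, 0.40, 0.95])
+ >>> labels = torch.tensor([1, 2, 150]) # 150 == num_labels, i.e. the "no object" class
+ >>> masks, scores, labels = remove_low_and_no_objects(masks, scores, labels, 0.8, num_labels=150)
+ >>> scores # only query 0 is both confident and assigned a real class
+ tensor([0.9000])
+ ```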
+ """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +# Copied from transformers.models.detr.image_processing_detr.check_segment_validity +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_area = (mask_probs[k] >= mask_threshold).sum() + mask_exists = mask_k_area > 0 and original_area > 0 + + # Eliminate disconnected tiny segments + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, mask_k + + +# Copied from transformers.models.detr.image_processing_detr.compute_segments +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_size: Tuple[int, int] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: List[Dict] = [] + + if target_size is not None: + mask_probs = nn.functional.interpolate( + mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False + )[0] + + current_segment_id = 0 + + # Weigh each mask by its prediction score + mask_probs *= pred_scores.view(-1, 1, 1) + mask_labels = mask_probs.argmax(0) # [height, width] + + # Keep track of instances of each class + stuff_memory_list: Dict[str, int] = {} + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + should_fuse = pred_class in label_ids_to_fuse + + # Check if mask exists and large enough to be a segment + mask_exists, mask_k = check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if mask_exists: + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + + # Add current object segment to final segmentation map + segmentation[mask_k] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "was_fused": should_fuse, + "score": segment_score, + } + ) + if should_fuse: + stuff_memory_list[pred_class] = current_segment_id + + return segmentation, segments + + +# Copied from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks +def convert_segmentation_map_to_binary_masks( + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ignore_index: Optional[int] = None, + do_reduce_labels: bool = False, +): + if do_reduce_labels and ignore_index is None: + raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.") + + if do_reduce_labels: + segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1) + + # Get unique ids (class or instance ids based on input) + all_labels = np.unique(segmentation_map) + + # Drop background label if applicable + if 
ignore_index is not None:
+ all_labels = all_labels[all_labels != ignore_index]
+
+ # Generate a binary mask for each object instance
+ binary_masks = [(segmentation_map == i) for i in all_labels]
+
+ # Stack the binary masks
+ if binary_masks:
+ binary_masks = np.stack(binary_masks, axis=0)
+ else:
+ binary_masks = np.zeros((0, *segmentation_map.shape))
+
+ # Convert instance ids to class ids
+ if instance_id_to_semantic_id is not None:
+ labels = np.zeros(all_labels.shape[0])
+
+ for label in all_labels:
+ class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
+ labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
+ else:
+ labels = all_labels
+
+ return binary_masks.astype(np.float32), labels.astype(np.int64)
+
+
+def get_oneformer_resize_output_image_size(
+ image: np.ndarray,
+ size: Union[int, Tuple[int, int], List[int], Tuple[int]],
+ max_size: Optional[int] = None,
+ default_to_square: bool = True,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple:
+ """
+ Computes the output size given the desired size.
+
+ Args:
+ image (`np.ndarray`):
+ The input image.
+ size (`int` or `Tuple[int, int]` or `List[int]` or `Tuple[int]`):
+ The size of the output image.
+ max_size (`int`, *optional*):
+ The maximum size of the output image.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ Whether to default to square if no size is provided.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+
+ Returns:
+ `Tuple[int, int]`: The output size.
+ """
+ output_size = get_resize_output_image_size(
+ input_image=image,
+ size=size,
+ default_to_square=default_to_square,
+ max_size=max_size,
+ input_data_format=input_data_format,
+ )
+ return output_size
+
+
+def prepare_metadata(class_info):
+ metadata = {}
+ class_names = []
+ thing_ids = []
+ for key, info in class_info.items():
+ metadata[key] = info["name"]
+ class_names.append(info["name"])
+ if info["isthing"]:
+ thing_ids.append(int(key))
+ metadata["thing_ids"] = thing_ids
+ metadata["class_names"] = class_names
+ return metadata
+
+
+def load_metadata(repo_id, class_info_file):
+ fname = os.path.join("" if repo_id is None else repo_id, class_info_file)
+
+ if not os.path.exists(fname) or not os.path.isfile(fname):
+ if repo_id is None:
+ raise ValueError(f"Could not find {fname} locally. repo_id must be defined if loading from the hub")
+ # We try downloading from a dataset by default for backward compatibility
+ try:
+ fname = hf_hub_download(repo_id, class_info_file, repo_type="dataset")
+ except RepositoryNotFoundError:
+ fname = hf_hub_download(repo_id, class_info_file)
+
+ with open(fname, "r") as f:
+ class_info = json.load(f)
+
+ return class_info
+
+
+class OneFormerImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a OneFormer image processor. The image processor can be used to prepare image(s), task input(s) and
+ optional text inputs and targets for the model.
+
+ This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int`, *optional*, defaults to 800):
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
If size is a
+ sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
+ the image will be matched to this number. i.e. if `height > width`, then image will be rescaled to `(size *
+ height / width, size)`.
+ resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
+ `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
+ `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
+ to `True`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the input to a certain `scale`.
+ rescale_factor (`float`, *optional*, defaults to `1 / 255`):
+ Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
+ ImageNet std.
+ ignore_index (`int`, *optional*):
+ Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
+ denoted with 0 (background) will be replaced with `ignore_index`.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
+ The background label will be replaced by `ignore_index`.
+ repo_path (`str`, *optional*, defaults to `"shi-labs/oneformer_demo"`):
+ Path to a hub repo or a local directory containing the JSON file with class information for the dataset.
+ If unset, will look for `class_info_file` in the current working directory.
+ class_info_file (`str`, *optional*):
+ JSON file containing class information for the dataset. See `shi-labs/oneformer_demo/cityscapes_panoptic.json` for an example.
+ num_text (`int`, *optional*):
+ Number of text entries in the text input list.
+ num_labels (`int`, *optional*):
+ The number of labels in the segmentation map. 
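+
+ Example (a minimal sketch; it assumes the default `shi-labs/oneformer_demo` repo is reachable so that
+ `cityscapes_panoptic.json` can be downloaded):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import OneFormerImageProcessor
+
+ >>> processor = OneFormerImageProcessor(class_info_file="cityscapes_panoptic.json")
+ >>> image = np.zeros((512, 512, 3), dtype=np.uint8)
+ >>> inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
+ >>> # inputs holds "pixel_values", "pixel_mask" and the still-untokenized "task_inputs"
+ ```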
+ """ + + model_input_names = ["pixel_values", "pixel_mask", "task_inputs"] + + @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0") + @deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True) + @filter_out_non_signature_kwargs(extra=["max_size", "metadata", *INIT_SERVICE_KWARGS]) + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = True, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + ignore_index: Optional[int] = None, + do_reduce_labels: bool = False, + repo_path: Optional[str] = "shi-labs/oneformer_demo", + class_info_file: Optional[str] = None, + num_text: Optional[int] = None, + num_labels: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + + # Deprecated, backward compatibility + self._max_size = kwargs.pop("max_size", 1333) + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size} + size = get_size_dict(size, max_size=self._max_size, default_to_square=False) + + if class_info_file is None: + raise ValueError("You must provide a `class_info_file`") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.ignore_index = ignore_index + self.do_reduce_labels = do_reduce_labels + self.class_info_file = class_info_file + self.repo_path = repo_path + self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file)) + self.num_text = num_text + self.num_labels = num_labels + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs + """ + image_processor_dict = image_processor_dict.copy() + if "reduce_labels" in image_processor_dict: + image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the + `_max_size` attribute from the dictionary. + """ + image_processor_dict = super().to_dict() + image_processor_dict.pop("_max_size", None) + return image_processor_dict + + @deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True) + @filter_out_non_signature_kwargs(extra=["max_size"]) + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format=None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. 
+ """ + + # Deprecated, backward compatibility + max_size = kwargs.pop("max_size", None) + + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + size, max_size = size["shortest_edge"], size["longest_edge"] + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + max_size = None + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + size = get_oneformer_resize_output_image_size( + image=image, size=size, max_size=max_size, default_to_square=False, input_data_format=input_data_format + ) + image = resize( + image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format + ) + return image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+ """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks + def convert_segmentation_map_to_binary_masks( + self, + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ignore_index: Optional[int] = None, + do_reduce_labels: bool = False, + ): + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + return convert_segmentation_map_to_binary_masks( + segmentation_map=segmentation_map, + instance_id_to_semantic_id=instance_id_to_semantic_id, + ignore_index=ignore_index, + do_reduce_labels=do_reduce_labels, + ) + + def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature: + return self.preprocess(images, task_inputs=task_inputs, segmentation_maps=segmentation_maps, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_resize: + image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + if do_rescale: + image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format) + if do_normalize: + image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format) + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ )
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ image = self._preprocess(
+ image=image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ input_data_format=input_data_format,
+ )
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ return image
+
+ def _preprocess_mask(
+ self,
+ segmentation_map: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single mask."""
+ segmentation_map = to_numpy_array(segmentation_map)
+ # Add channel dimension if missing - needed for certain transformations
+ if segmentation_map.ndim == 2:
+ added_channel_dim = True
+ segmentation_map = segmentation_map[None, ...]
+ input_data_format = ChannelDimension.FIRST
+ else:
+ added_channel_dim = False
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
+ # TODO: (Amy)
+ # Rework segmentation map processing to include reducing labels and resizing, which doesn't
+ # drop segment IDs > 255.
+ segmentation_map = self._preprocess(
+ image=segmentation_map,
+ do_resize=do_resize,
+ resample=PILImageResampling.NEAREST,
+ size=size,
+ do_rescale=False,
+ do_normalize=False,
+ input_data_format=input_data_format,
+ )
+ # Remove extra channel dimension if added for processing
+ if added_channel_dim:
+ segmentation_map = segmentation_map.squeeze(0)
+ return segmentation_map
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ task_inputs: Optional[List[str]] = None,
+ segmentation_maps: Optional[ImageInput] = None,
+ instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ if task_inputs is None:
+ # Default value
+ task_inputs = ["panoptic"]
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ ignore_index = ignore_index if ignore_index is not None else self.ignore_index
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+
+ if not valid_images(images):
+ 
raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + + if segmentation_maps is not None and len(images) != len(segmentation_maps): + raise ValueError("Images and segmentation maps must have the same length.") + + images = [ + self._preprocess_image( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for image in images + ] + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_mask(segmentation_map, do_resize, size, input_data_format=input_data_format) + for segmentation_map in segmentation_maps + ] + encoded_inputs = self.encode_inputs( + images, + task_inputs, + segmentation_maps, + instance_id_to_semantic_id, + ignore_index, + do_reduce_labels, + return_tensors, + input_data_format=data_format, + ) + return encoded_inputs + + # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: Tuple[int, int], + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + return padded_image + + # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad + def pad( + self, + images: List[np.ndarray], + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + image (`np.ndarray`): + Image to pad. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. 
Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ padded_images = [
+ self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ data = {"pixel_values": padded_images}
+
+ if return_pixel_mask:
+ masks = [
+ make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
+ for image in images
+ ]
+ data["pixel_mask"] = masks
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_semantic_annotations(self, label, num_class_obj):
+ annotation_classes = label["classes"]
+ annotation_masks = label["masks"]
+
+ texts = ["a semantic photo"] * self.num_text
+ classes = []
+ masks = []
+
+ for idx in range(len(annotation_classes)):
+ class_id = annotation_classes[idx]
+ mask = annotation_masks[idx]
+ if mask.any(): # keep only masks with at least one foreground pixel
+ if class_id not in classes:
+ cls_name = self.metadata[str(class_id)]
+ classes.append(class_id)
+ masks.append(mask)
+ num_class_obj[cls_name] += 1
+ else:
+ idx = classes.index(class_id)
+ masks[idx] += mask
+ masks[idx] = np.clip(masks[idx], 0, 1)
+
+ num = 0
+ for i, cls_name in enumerate(self.metadata["class_names"]):
+ if num_class_obj[cls_name] > 0:
+ for _ in range(num_class_obj[cls_name]):
+ if num >= len(texts):
+ break
+ texts[num] = f"a photo with a {cls_name}"
+ num += 1
+
+ classes = np.array(classes)
+ masks = np.array(masks)
+ return classes, masks, texts
+
+ def get_instance_annotations(self, label, num_class_obj):
+ annotation_classes = label["classes"]
+ annotation_masks = label["masks"]
+
+ texts = ["an instance photo"] * self.num_text
+ classes = []
+ masks = []
+
+ for idx in range(len(annotation_classes)):
+ class_id = annotation_classes[idx]
+ mask = annotation_masks[idx]
+
+ if class_id in self.metadata["thing_ids"]:
+ if mask.any():
+ cls_name = self.metadata[str(class_id)]
+ classes.append(class_id)
+ masks.append(mask)
+ num_class_obj[cls_name] += 1
+
+ num = 0
+ for i, cls_name in enumerate(self.metadata["class_names"]):
+ if num_class_obj[cls_name] > 0:
+ for _ in range(num_class_obj[cls_name]):
+ if num >= len(texts):
+ break
+ texts[num] = f"a photo with a {cls_name}"
+ num += 1
+
+ classes = np.array(classes)
+ masks = np.array(masks)
+ return classes, masks, texts
+
+ def get_panoptic_annotations(self, label, num_class_obj):
+ annotation_classes = label["classes"]
+ annotation_masks = label["masks"]
+
+ texts = ["an panoptic photo"] * self.num_text
+ classes = []
+ masks = []
+
+ for idx in range(len(annotation_classes)):
+ class_id = annotation_classes[idx]
+ mask = annotation_masks[idx].data
+ if mask.any():
+ cls_name = self.metadata[str(class_id)]
+ classes.append(class_id)
+ masks.append(mask)
+ num_class_obj[cls_name] += 1
+
+ num = 0
+ for i, cls_name in
enumerate(self.metadata["class_names"]):
+ if num_class_obj[cls_name] > 0:
+ for _ in range(num_class_obj[cls_name]):
+ if num >= len(texts):
+ break
+ texts[num] = f"a photo with a {cls_name}"
+ num += 1
+
+ classes = np.array(classes)
+ masks = np.array(masks)
+ return classes, masks, texts
+
+ def encode_inputs(
+ self,
+ pixel_values_list: List[ImageInput],
+ task_inputs: List[str],
+ segmentation_maps: ImageInput = None,
+ instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
+ ignore_index: Optional[int] = None,
+ do_reduce_labels: bool = False,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ):
+ """
+ Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
+
+ OneFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+ will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
+ `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+ [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+ each mask.
+
+ Args:
+ pixel_values_list (`List[ImageInput]`):
+ List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
+ width)`.
+
+ task_inputs (`List[str]`):
+ List of task values.
+
+ segmentation_maps (`ImageInput`, *optional*):
+ The corresponding semantic segmentation maps with the pixel-wise annotations.
+
+ instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
+ A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
+ instance segmentation map where each pixel represents an instance id. Can be provided as a single
+ dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
+ instance ids in each image separately.
+
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
+ objects.
+
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
+ image.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ - **pixel_mask** -- Pixel mask to be fed to a model (when `pixel_mask` is in `self.model_input_names`).
+ Images are padded up to the largest image in the batch; the mask is 1 for pixels that are real (i.e.
+ **not masked**) and 0 for pixels that are padding (i.e. **masked**).
+ - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
+ (when `annotations` are provided).
+ - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
+ `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
+ `mask_labels[i][j]` is `class_labels[i][j]`.
+ - **task_inputs** -- List of task strings of the form `"the task is {task_input}"`, to be tokenized
+ before being fed to a model.
+ - **text_inputs** -- Optional list of text string entries to be fed to a model (when `annotations` are
+ provided). They identify the binary masks present in the image.
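+
+ Example (a sketch; assumes the class-info file can be fetched, and passes no segmentation maps so that no
+ annotation-dependent fields are produced):
+
+ ```python
+ >>> import numpy as np
+ >>> from transformers import OneFormerImageProcessor
+
+ >>> processor = OneFormerImageProcessor(class_info_file="cityscapes_panoptic.json")
+ >>> images = [np.zeros((480, 640, 3), dtype=np.uint8), np.zeros((512, 512, 3), dtype=np.uint8)]
+ >>> batch = processor.encode_inputs(images, task_inputs=["semantic", "panoptic"])
+ >>> # both images are padded to the batch-wide maximum height and width, here (512, 640)
+ ```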
+ """ + ignore_index = self.ignore_index if ignore_index is None else ignore_index + do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels + pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(pixel_values_list[0]) + + pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format) + encoded_inputs = self.pad( + pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format + ) + + annotations = None + if segmentation_maps is not None: + segmentation_maps = map(np.array, segmentation_maps) + annotations = [] + for idx, segmentation_map in enumerate(segmentation_maps): + # Use instance2class_id mapping per image + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = self.convert_segmentation_map_to_binary_masks( + segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels + ) + annotations.append({"masks": masks, "classes": classes}) + + if annotations is not None: + mask_labels = [] + class_labels = [] + text_inputs = [] + + num_class_obj = {} + for cls_name in self.metadata["class_names"]: + num_class_obj[cls_name] = 0 + + for i, label in enumerate(annotations): + task = task_inputs[i] + if task == "semantic": + classes, masks, texts = self.get_semantic_annotations(label, num_class_obj) + elif task == "instance": + classes, masks, texts = self.get_instance_annotations(label, num_class_obj) + elif task == "panoptic": + classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj) + else: + raise ValueError(f"{task} was not expected, expected `semantic`, `instance` or `panoptic`") + + # we cannot batch them since they don't share a common class size + masks = [mask[None, ...] for mask in masks] + masks = [ + self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks + ] + masks = np.concatenate(masks, axis=0) + mask_labels.append(torch.from_numpy(masks)) + class_labels.append(torch.from_numpy(classes).long()) + text_inputs.append(texts) + + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels + encoded_inputs["text_inputs"] = text_inputs + + # This needs to be tokenized before sending to the model. + encoded_inputs["task_inputs"] = [f"the task is {task_input}" for task_input in task_inputs] + + return encoded_inputs + + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None + ) -> "torch.Tensor": + """ + Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. 
+ Returns:
+ `List[torch.Tensor]`:
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+ corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
+ `torch.Tensor` corresponds to a semantic class id.
+ """
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
+
+ # Remove the null class `[..., :-1]`
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
+
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
+ batch_size = class_queries_logits.shape[0]
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if batch_size != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ semantic_segmentation = []
+ for idx in range(batch_size):
+ resized_logits = torch.nn.functional.interpolate(
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = segmentation.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
+
+ def post_process_instance_segmentation(
+ self,
+ outputs,
+ task_type: str = "instance",
+ is_demo: bool = True,
+ threshold: float = 0.5,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
+ return_coco_annotation: Optional[bool] = False,
+ ):
+ """
+ Converts the output of [`OneFormerForUniversalSegmentationOutput`] into image instance segmentation
+ predictions. Only supports PyTorch.
+
+ Args:
+ outputs ([`OneFormerForUniversalSegmentationOutput`]):
+ The outputs from [`OneFormerForUniversalSegmentation`].
+ task_type (`str`, *optional*, defaults to "instance"):
+ The post processing depends on the task token input. If the `task_type` is "panoptic", we need to
+ ignore the stuff predictions.
+ is_demo (`bool`, *optional*, defaults to `True`):
+ Whether the model is in demo mode. If True, uses `threshold` to predict final masks.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The probability score threshold to keep predicted instance masks.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
+ instance mask.
+ target_sizes (`List[Tuple]`, *optional*):
+ List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction in batch. If left to None, predictions will not be
+ resized.
+ return_coco_annotation (`bool`, *optional*, defaults to `False`):
+ Whether to return predictions in COCO format. 
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
+ to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
+ to the corresponding `target_sizes` entry.
+ - **segments_info** -- A dictionary that contains additional information on each segment.
+ - **id** -- an integer representing the `segment_id`.
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+ - **score** -- Prediction score of segment with `segment_id`.
+ """
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
+
+ device = masks_queries_logits.device
+ batch_size = class_queries_logits.shape[0]
+ num_queries = class_queries_logits.shape[1]
+ num_classes = class_queries_logits.shape[-1] - 1
+
+ # Loop over items in batch size
+ results: List[Dict[str, torch.Tensor]] = []
+
+ for i in range(batch_size):
+ # [Q, K]
+ scores = torch.nn.functional.softmax(class_queries_logits[i], dim=-1)[:, :-1]
+ labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
+ labels_per_image = labels[topk_indices]
+
+ topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+ mask_pred = masks_queries_logits[i][topk_indices]
+
+ # Only consider scores with confidence over [threshold] for demo
+ if is_demo:
+ keep = scores_per_image > threshold
+ scores_per_image = scores_per_image[keep]
+ labels_per_image = labels_per_image[keep]
+ mask_pred = mask_pred[keep]
+
+ # if this is panoptic segmentation, we only keep the "thing" classes
+ if task_type == "panoptic":
+ keep = torch.zeros_like(scores_per_image).bool()
+ for j, lab in enumerate(labels_per_image):
+ keep[j] = lab in self.metadata["thing_ids"]
+
+ scores_per_image = scores_per_image[keep]
+ labels_per_image = labels_per_image[keep]
+ mask_pred = mask_pred[keep]
+
+ if mask_pred.shape[0] <= 0:
+ height, width = target_sizes[i] if target_sizes is not None else mask_pred.shape[1:]
+ segmentation = torch.zeros((height, width)) - 1
+ results.append({"segmentation": segmentation, "segments_info": []})
+ continue
+
+ if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
+ for j in range(labels_per_image.shape[0]):
+ labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())
+
+ # Get segmentation map and segment information of batch item
+ target_size = target_sizes[i] if target_sizes is not None else None
+ segmentation, segments = compute_segments(
+ mask_pred,
+ scores_per_image,
+ labels_per_image,
+ mask_threshold,
+ overlap_mask_area_threshold,
+ set(),
+ target_size,
+ )
+
+ # Return segmentation map in run-length encoding (RLE) format
+ if return_coco_annotation:
+ segmentation = convert_segmentation_to_rle(segmentation)
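+ # in that case each `segmentation` entry is a list of run-length encodings (one per
+ # `segment_id`) rather than a `(height, width)` tensor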
+
+ results.append({"segmentation": segmentation, "segments_info": segments})
+ return results
+
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_panoptic_segmentation
+ def post_process_panoptic_segmentation(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ mask_threshold: float = 0.5,
+ overlap_mask_area_threshold: float = 0.8,
+ label_ids_to_fuse: Optional[Set[int]] = None,
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
+ ) -> List[Dict]:
+ """
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
+ predictions. Only supports PyTorch.
+
+ Args:
+ outputs ([`MaskFormerForInstanceSegmentationOutput`]):
+ The outputs from [`MaskFormerForInstanceSegmentation`].
+ threshold (`float`, *optional*, defaults to 0.5):
+ The probability score threshold to keep predicted instance masks.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
+ instance mask.
+ label_ids_to_fuse (`Set[int]`, *optional*):
+ The labels in this set will have all their instances be fused together. For instance we could say
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
+ set, but not the one for person.
+ target_sizes (`List[Tuple]`, *optional*):
+ List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
+ final size (height, width) of each prediction in batch. If left to None, predictions will not be
+ resized.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
+ to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
+ to the corresponding `target_sizes` entry.
+ - **segments_info** -- A dictionary that contains additional information on each segment.
+ - **id** -- an integer representing the `segment_id`.
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+ - **score** -- Prediction score of segment with `segment_id`.
+ """
+
+ if label_ids_to_fuse is None:
+ logger.warning("`label_ids_to_fuse` unset. 
No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: List[Dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["OneFormerImageProcessor"] diff --git a/docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py b/docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3ecb3611b3cfce06f477815569e9b86458b8b2 --- /dev/null +++ b/docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py @@ -0,0 +1,3262 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""PyTorch OneFormer model."""
+
+import copy
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+from torch.cuda.amp import autocast
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_accelerate_available,
+ is_scipy_available,
+ logging,
+ replace_return_docstrings,
+ requires_backends,
+)
+from ...utils.backbone_utils import load_backbone
+from .configuration_oneformer import OneFormerConfig
+
+
+if is_accelerate_available():
+ from accelerate import PartialState
+ from accelerate.utils import reduce
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "OneFormerConfig"
+_CHECKPOINT_FOR_DOC = "shi-labs/oneformer_ade20k_swin_tiny"
+
+
+if is_scipy_available():
+ from scipy.optimize import linear_sum_assignment
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def multi_scale_deformable_attention(
+ value: Tensor,
+ value_spatial_shapes: Union[Tensor, List[Tuple]],
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+) -> Tensor:
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = (
+ value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
+ )
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
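+
+
+# Note on `multi_scale_deformable_attention`: `sampling_locations` are normalized to [0, 1] over each feature
+# map, while `nn.functional.grid_sample` expects coordinates in [-1, 1]; the `2 * sampling_locations - 1`
+# rescaling above converts between the two conventions.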
+
+
+# Copied from transformers.models.maskformer.modeling_maskformer.dice_loss
+def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
+ r"""
+ Compute the DICE loss, similar to generalized IOU for masks, as follows:
+
+ $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$
+
+ In practice, since `labels` is a binary mask (only 0s and 1s), dice can be computed as follows:
+
+ $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y}{x + y + 1} $$
+
+ Args:
+ inputs (`torch.Tensor`):
+ A tensor representing a mask.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ num_masks (`int`):
+ The number of masks present in the current batch, used for normalization.
+
+ Returns:
+ `torch.Tensor`: The computed loss.
+ """
+ probs = inputs.sigmoid().flatten(1)
+ numerator = 2 * (probs * labels).sum(-1)
+ denominator = probs.sum(-1) + labels.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ loss = loss.sum() / num_masks
+ return loss
+
+
+# Copied from transformers.models.mask2former.modeling_mask2former.sigmoid_cross_entropy_loss
+def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
+ r"""
+ Args:
+ inputs (`torch.Tensor`):
+ A float tensor of arbitrary shape.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ num_masks (`int`):
+ The number of masks present in the current batch, used for normalization.
+
+ Returns:
+ loss (`torch.Tensor`): The computed loss.
+ """
+ criterion = nn.BCEWithLogitsLoss(reduction="none")
+ cross_entropy_loss = criterion(inputs, labels)
+
+ loss = cross_entropy_loss.mean(1).sum() / num_masks
+ return loss
+
+
+# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss
+def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
+ """
+ A pair wise version of the dice loss, see `dice_loss` for usage.
+
+ Args:
+ inputs (`torch.Tensor`):
+ A tensor representing a mask.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+
+ Returns:
+ `torch.Tensor`: The computed loss between each pair.
+ """
+ inputs = inputs.sigmoid().flatten(1)
+ numerator = 2 * torch.matmul(inputs, labels.T)
+ # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
+ denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
+
+
+# Copied from transformers.models.mask2former.modeling_mask2former.pair_wise_sigmoid_cross_entropy_loss
+def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ r"""
+ A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
+
+ Args:
+ inputs (`torch.Tensor`):
+ A tensor representing a mask.
+ labels (`torch.Tensor`):
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+
+ Returns:
+ loss (`torch.Tensor`): The computed loss between each pair.
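+
+ Note: the result is a `(num_queries, num_targets)` cost matrix; entry `(i, j)` scores predicted mask `i`
+ against target mask `j`, averaged over the `height_and_width` mask locations.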
+ """ + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T) + loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T) + loss = loss_pos + loss_neg + return loss + + +# Copied from transformers.models.mask2former.modeling_mask2former.sample_point +def sample_point( + input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs +) -> torch.Tensor: + """ + A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. + + Args: + input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): + A tensor that contains features map on a height * width grid + point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,: + 2)): + A tensor that contains [0, 1] * [0, 1] normalized point coordinates + add_dim (`bool`): + boolean value to keep track of added dimension + + Returns: + point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, + height_grid, width_grid): + A tensor that contains features for points in `point_coordinates`. + """ + if point_coordinates.dim() == 3: + add_dim = True + point_coordinates = point_coordinates.unsqueeze(2) + + # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation + point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs) + if add_dim: + point_features = point_features.squeeze(3) + + return point_features + + +# Refactored from https://github.com/SHI-Labs/OneFormer/blob/33ebb56ed34f970a30ae103e786c0cb64c653d9a/oneformer/modeling/matcher.py#L93 +class OneFormerHungarianMatcher(nn.Module): + def __init__( + self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 + ): + """This class computes an assignment between the labels and the predictions of the network. + + For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more + predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + + Params: + cost_class (float, *optional*, defaults to 1.0): + This is the relative weight of the classification error in the matching cost. + cost_mask (float, *optional*, defaults to 1.0): + This is the relative weight of the sigmoid ce loss of the binary mask in the matching cost. + cost_dice (float, *optional*, defaults to 1.0): + This is the relative weight of the dice loss of the binary mask in the matching cost + num_points (int, *optional*, defaults to 12544): + Number of points to be sampled for dice and mask loss matching cost. 
+ """ + super().__init__() + if cost_class == 0 and cost_mask == 0 and cost_dice == 0: + raise ValueError("All costs cant be 0") + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + self.num_points = num_points + + @torch.no_grad() + def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> List[Tuple[Tensor]]: + """Performs the matching + + Params: + masks_queries_logits (`torch.Tensor`): + A tensor` of dim `batch_size, num_queries, num_labels` with the + classification logits. + class_queries_logits (`torch.Tensor`): + A tensor` of dim `batch_size, num_queries, height, width` with the + predicted masks. + + class_labels (`torch.Tensor`): + A tensor` of dim `num_target_boxes` (where num_target_boxes is the number + of ground-truth objects in the target) containing the class labels. + mask_labels (`torch.Tensor`): + A tensor` of dim `num_target_boxes, height, width` containing the target + masks. + + Returns: + `List[Tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected labels (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_targets). + """ + indices: List[Tuple[np.array]] = [] + + num_queries = class_queries_logits.shape[1] + + preds_masks = masks_queries_logits + preds_probs = class_queries_logits + # iterate through batch size + for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels): + pred_probs = pred_probs.softmax(-1) + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -pred_probs[:, labels] + + pred_mask = pred_mask[:, None] + target_mask = target_mask[:, None].to(pred_mask.device) + + # all masks share the same set of points for efficient matching! 
+ point_coords = torch.rand(1, self.num_points, 2, device=pred_mask.device)
+
+ # get ground truth labels
+ target_mask = sample_point(
+ target_mask,
+ point_coords.repeat(target_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ pred_mask = sample_point(
+ pred_mask,
+ point_coords.repeat(pred_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ pred_mask = pred_mask.float()
+ target_mask = target_mask.float()
+
+ # compute the sigmoid ce loss
+ cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
+ # Compute the dice loss
+ cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
+ # final cost matrix
+ cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
+ cost_matrix = cost_matrix.reshape(num_queries, -1).cpu()
+ # do the assignment using the Hungarian algorithm from scipy
+ assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
+ indices.append(assigned_indices)
+
+ # The matched indices could be stacked into one tensor
+ matched_indices = [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
+ ]
+ return matched_indices
+
+
+class OneFormerLoss(nn.Module):
+ def __init__(
+ self,
+ num_classes: int,
+ matcher: OneFormerHungarianMatcher,
+ weight_dict: Dict[str, float],
+ eos_coef: float,
+ num_points: int,
+ oversample_ratio: float,
+ importance_sample_ratio: float,
+ contrastive_temperature: Optional[float] = None,
+ ):
+ """
+ This class computes the losses using the class predictions, mask predictions and the contrastive queries.
+
+ OneFormer calculates the classification CE loss on the class predictions. Mask predictions are used for
+ calculating the binary CE loss and dice loss. The contrastive queries are used for calculating the contrastive
+ loss.
+
+ Args:
+ num_classes (`int`):
+ The number of classes.
+ matcher (`OneFormerHungarianMatcher`):
+ A torch module that computes the assignments between the predictions and labels.
+ weight_dict (`Dict[str, float]`):
+ A dictionary of weights to be applied to the different losses.
+ eos_coef (`float`):
+ Weight to apply to the null class.
+ num_points (`int`):
+ Number of points to be sampled for dice and mask loss calculations.
+ oversample_ratio (`float`):
+ Required for pointwise loss calculation.
+ importance_sample_ratio (`float`):
+ Required for pointwise loss calculation.
+ contrastive_temperature (`float`):
+ Temperature for scaling the contrastive logits. 
+ """ + requires_backends(self, ["scipy"]) + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + # pointwise mask loss parameters + self.num_points = num_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.contrastive_temperature = contrastive_temperature + if self.contrastive_temperature is not None: + self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / contrastive_temperature))) + + def _max_by_axis(self, the_list: List[List[int]]) -> List[int]: + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + batch_size = len(tensors) + # compute finel size + batch_shape = [batch_size] + max_size + b, _, h, w = batch_shape + # get metadata + dtype = tensors[0].dtype + device = tensors[0].device + padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) + padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device) + # pad the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + + def loss_contrastive(self, contrastive_queries_logits: Tensor, text_queries: Tensor): + """Compute the query-text contrastive loss. + + Args: + contrastive_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + text_queries (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_contrastive** -- The query-text contrastive loss computed using task-guided queries + and text queries derived from input text list. + """ + + image_queries = contrastive_queries_logits.float() + + # [batch_size, hidden_dim] + image_queries = nn.functional.normalize(image_queries.flatten(1), dim=-1) + text_queries = nn.functional.normalize(text_queries.flatten(1), dim=-1) + + logit_scale = torch.clamp(self.logit_scale.exp(), max=100) + + logits_per_text = torch.matmul(text_queries, image_queries.t()) * logit_scale + logits_per_img = logits_per_text.t() + + loss_img = nn.functional.cross_entropy( + logits_per_img, torch.arange(len(logits_per_img), device=logits_per_text.device) + ) + loss_text = nn.functional.cross_entropy( + logits_per_text, torch.arange(len(logits_per_text), device=logits_per_text.device) + ) + + loss_contrastive = loss_img + loss_text + + losses = {"loss_contrastive": loss_contrastive} + return losses + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array] + ) -> Dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. + + Args: + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. 
+            indices (`Tuple[np.array]`):
+                The indices computed by the Hungarian matcher.
+
+        Returns:
+            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
+            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
+        """
+        pred_logits = class_queries_logits
+        batch_size, num_queries, _ = pred_logits.shape
+        criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
+        idx = self._get_predictions_permutation_indices(indices)
+
+        # concatenated labels of the matched targets, shape = (num_total_labels,)
+        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
+        # shape = (batch_size, num_queries)
+        target_classes = torch.full(
+            (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
+        )
+        target_classes[idx] = target_classes_o
+        # permute pred_logits (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
+        pred_logits_transposed = pred_logits.transpose(1, 2)
+        loss_ce = criterion(pred_logits_transposed, target_classes)
+        losses = {"loss_cross_entropy": loss_ce}
+        return losses
+
+    def loss_masks(
+        self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
+    ) -> Dict[str, Tensor]:
+        """Compute the losses related to the masks using sigmoid cross entropy loss and dice loss.
+
+        Args:
+            masks_queries_logits (`torch.Tensor`):
+                A tensor of shape `batch_size, num_queries, height, width`
+            mask_labels (`List[torch.Tensor]`):
+                List of mask labels of shape `(labels, height, width)`.
+            indices (`Tuple[np.array]`):
+                The indices computed by the Hungarian matcher.
+            num_masks (`int`):
+                The number of masks, used for normalization.
+
+        Returns:
+            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
+            - **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks.
+            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
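+
+        Example (an illustrative sketch in the style of the dice loss used on the point-sampled logits below; it is
+        not the exact `dice_loss` helper):
+
+        ```python
+        >>> import torch
+
+        >>> num_masks = 4
+        >>> point_logits = torch.randn(num_masks, 100)  # (num_masks, num_points) sampled mask logits
+        >>> point_labels = torch.randint(0, 2, (num_masks, 100)).float()  # sampled ground truth in {0, 1}
+        >>> probs = point_logits.sigmoid()
+        >>> numerator = 2 * (probs * point_labels).sum(-1)
+        >>> denominator = probs.sum(-1) + point_labels.sum(-1)
+        >>> loss_dice = (1 - (numerator + 1) / (denominator + 1)).sum() / num_masks  # normalized by num_masks
+        ```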
+ """ + src_idx = self._get_predictions_permutation_indices(indices) + tgt_idx = self._get_targets_permutation_indices(indices) + # shape (batch_size * num_queries, height, width) + pred_masks = masks_queries_logits[src_idx] + # shape (batch_size, num_queries, height, width) + # pad all and stack the targets to the num_labels dimension + # upsample predictions to the target size, we have to add one dim to use interpolate + target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) + target_masks = target_masks[tgt_idx] + + pred_masks = pred_masks[:, None] + target_masks = target_masks[:, None] + + with torch.no_grad(): + # sample point_coords + point_coords = self.sample_points_using_uncertainty( + pred_masks, + self.calculate_uncertainty, + self.num_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + # get ground-truth labels + point_labels = sample_point(target_masks, point_coords, align_corners=False).squeeze(1) + + point_logits = sample_point(pred_masks, point_coords, align_corners=False).squeeze(1) + + losses = { + "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), + "loss_dice": dice_loss(point_logits, point_labels, num_masks), + } + + del pred_masks + del target_masks + return losses + + # Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.calculate_uncertainty + def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: + """ + In Mask2Former paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits' + for the foreground class in `classes`. + + Args: + logits (`torch.Tensor`): + A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is: + the number of foreground classes. The values are logits. + + Returns: + scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most + uncertain locations having the highest uncertainty score. + """ + uncertainty_scores = -(torch.abs(logits)) + return uncertainty_scores + + # Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.sample_points_using_uncertainty + def sample_points_using_uncertainty( + self, + logits: torch.Tensor, + uncertainty_function, + num_points: int, + oversample_ratio: int, + importance_sample_ratio: float, + ) -> torch.Tensor: + """ + This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The + uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit + prediction as input. + + Args: + logits (`float`): + Logit predictions for P points. + uncertainty_function: + A function that takes logit predictions for P points and returns their uncertainties. + num_points (`int`): + The number of points P to sample. + oversample_ratio (`int`): + Oversampling parameter. + importance_sample_ratio (`float`): + Ratio of points that are sampled via importance sampling. + + Returns: + point_coordinates (`torch.Tensor`): + Coordinates for P sampled points. 
+ """ + + num_boxes = logits.shape[0] + num_points_sampled = int(num_points * oversample_ratio) + + # Get random point coordinates + point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) + # Get sampled prediction value for the point coordinates + point_logits = sample_point(logits, point_coordinates, align_corners=False) + # Calculate the uncertainties based on the sampled prediction values of the points + point_uncertainties = uncertainty_function(point_logits) + + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) + idx += shift[:, None] + point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + + if num_random_points > 0: + point_coordinates = torch.cat( + [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], + dim=1, + ) + return point_coordinates + + def _get_predictions_permutation_indices(self, indices): + # permute predictions following indices + batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + predictions_indices = torch.cat([src for (src, _) in indices]) + return batch_indices, predictions_indices + + def _get_targets_permutation_indices(self, indices): + # permute labels following indices + batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + target_indices = torch.cat([tgt for (_, tgt) in indices]) + return batch_indices, target_indices + + def forward( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + contrastive_queries_logits: Tensor, + mask_labels: List[Tensor], + class_labels: List[Tensor], + text_queries: Tensor, + auxiliary_predictions: Optional[Dict[str, Tensor]] = None, + calculate_contrastive_loss: bool = True, + ) -> Dict[str, Tensor]: + """ + This performs the loss computation. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, height, width` + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + contrastive_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + class_labels (`List[torch.Tensor]`): + List of class labels of shape `(labels)`. + text_queries (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, hidden_dim` + auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*): + if `use_auxiliary_loss` was set to `true` in [`OneFormerConfig`], then it contains the logits from the + inner layers of the Detr's Decoder. + calculate_contrastive_loss (`bool`, *optional*, defaults to `True`): + Whether or not to calculate the contrastive loss. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + - **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + - **loss_contrastive** -- The query-text contrstive loss computed using object and text queries. 
+            If `use_auxiliary_loss` was set to `true` in [`OneFormerConfig`], the dictionary contains additional
+            losses for each auxiliary prediction.
+        """
+
+        # retrieve the matching between the outputs of the last layer and the labels
+        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
+        # compute the average number of target masks for normalization purposes
+        num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
+        # get all the losses
+        losses: Dict[str, Tensor] = {
+            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
+            **self.loss_labels(class_queries_logits, class_labels, indices),
+        }
+        if calculate_contrastive_loss:
+            losses = {**losses, **self.loss_contrastive(contrastive_queries_logits, text_queries)}
+
+        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if auxiliary_predictions is not None:
+            for idx, aux_outputs in enumerate(auxiliary_predictions):
+                masks_queries_logits = aux_outputs["masks_queries_logits"]
+                class_queries_logits = aux_outputs["class_queries_logits"]
+                loss_dict = self.forward(
+                    masks_queries_logits,
+                    class_queries_logits,
+                    None,
+                    mask_labels,
+                    class_labels,
+                    None,
+                    calculate_contrastive_loss=False,
+                )
+                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
+                losses.update(loss_dict)
+
+        return losses
+
+    def get_num_masks(self, class_labels: List[torch.Tensor], device: torch.device) -> torch.Tensor:
+        """
+        Computes the average number of target masks across the batch, for normalization purposes.
+        """
+        num_masks = sum([len(classes) for classes in class_labels])
+        num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)
+        world_size = 1
+        if is_accelerate_available():
+            if PartialState._shared_state != {}:
+                num_masks = reduce(num_masks)
+                world_size = PartialState().num_processes
+
+        num_masks = torch.clamp(num_masks / world_size, min=1)
+        return num_masks
+
+
+@dataclass
+class OneFormerTransformerDecoderOutput(BaseModelOutput):
+    """
+    Base class for outputs of the Transformer decoder. This class adds attributes for class predictions, mask
+    predictions and contrastive logits to `BaseModelOutput`.
+
+    Args:
+        object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
+            Queries representation for the region proposals.
+        contrastive_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
+            Queries representation for the contrastive loss.
+        prediction_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
+            Mask predictions from the last layer of the transformer decoder.
+        prediction_class (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
+            Class predictions from the last layer of the transformer decoder.
+        auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
+            Tuple of class and mask predictions from each layer of the transformer decoder.
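+
+    Example (an illustrative sketch of the structure `auxiliary_predictions` is expected to have; the ellipses stand
+    in for real tensors):
+
+    ```python
+    >>> auxiliary_predictions = (
+    ...     {"masks_queries_logits": ..., "class_queries_logits": ...},  # one dict per decoder layer
+    ... )
+    ```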
+ """ + + object_queries: Optional[torch.FloatTensor] = None + contrastive_logits: Optional[torch.FloatTensor] = None + prediction_masks: Optional[torch.FloatTensor] = None + prediction_class: Optional[torch.FloatTensor] = None + auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None + + +@dataclass +# Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoderOutput with Mask2->One +class OneFormerPixelDecoderOutput(ModelOutput): + """ + OneFormer's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns + the mask features and the multiscale features. + + Args: + multi_scale_features (`tuple(torch.FloatTensor)`): + Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height, + width)`from the Multi-Scale Deformable Attenntion based Pixel Decoder. + mask_features (`torch.FloatTensor`): + Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder + Layer. + attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed + or when `config.output_attentions=True` + """ + + multi_scale_features: Tuple[torch.FloatTensor] = None + mask_features: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class OneFormerPixelLevelModuleOutput(ModelOutput): + """ + OneFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the + `encoder` and `decoder`. By default, the `encoder` is a Swin/Dinat Backbone and the `decoder` is a Multi-Scale + Deformable Attention based decoder. + + Args: + encoder_features (List of `(torch.FloatTensor)`): + List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + decoder_features (List of `(torch.FloatTensor)`): + List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + decoder_last_feature (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)): + 1/4 scale features from the last Pixel Decoder Layer. + """ + + encoder_features: List[torch.FloatTensor] = None + decoder_features: List[torch.FloatTensor] = None + decoder_last_feature: Optional[torch.FloatTensor] = None + + +@dataclass +class OneFormerModelOutput(ModelOutput): + """ + Class for outputs of [`OneFormerModel`]. This class returns all the needed hidden states to compute the logits. + + Args: + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. 
+        pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
+            decoder model at the output of each stage.
+        transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
+            transformer decoder at the output of each stage.
+        transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
+            Output object queries from the last layer in the transformer decoder.
+        transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
+            Contrastive queries from the transformer decoder.
+        transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
+            Mask predictions from the last layer in the transformer decoder.
+        transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
+            Class predictions from the last layer in the transformer decoder.
+        transformer_decoder_auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
+            Tuple of class and mask predictions from each layer of the transformer decoder.
+        text_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
+            Text queries derived from the input text list, used for calculating the contrastive loss during training.
+        task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
+            1D task token to condition the queries.
+        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Self- and cross-attention weights from the transformer decoder.
+    """
+
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    transformer_decoder_object_queries: Optional[torch.FloatTensor] = None
+    transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
+    transformer_decoder_mask_predictions: Optional[torch.FloatTensor] = None
+    transformer_decoder_class_predictions: Optional[torch.FloatTensor] = None
+    transformer_decoder_auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
+    text_queries: Optional[torch.FloatTensor] = None
+    task_token: Optional[torch.FloatTensor] = None
+    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+
+
+@dataclass
+class OneFormerForUniversalSegmentationOutput(ModelOutput):
+    """
+    Class for outputs of [`OneFormerForUniversalSegmentation`].
+
+    This output can be directly passed to [`~OneFormerImageProcessor.post_process_semantic_segmentation`],
+    [`~OneFormerImageProcessor.post_process_instance_segmentation`] or
+    [`~OneFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please see
+    [`~OneFormerImageProcessor`] for details regarding usage.
+
+    Args:
+        loss (`torch.Tensor`, *optional*):
+            The computed loss, returned when labels are present.
+        class_queries_logits (`torch.FloatTensor`):
+            A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
+            query. Note the `+ 1` is needed because we incorporate the null class.
+        masks_queries_logits (`torch.FloatTensor`):
+            A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
+            query.
+        auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*):
+            List of class and mask predictions from each layer of the transformer decoder.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
+            model at the output of each stage.
+        pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
+            decoder model at the output of each stage.
+        transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
+            transformer decoder at the output of each stage.
+        transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
+            Output object queries from the last layer in the transformer decoder.
+        transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
+            Contrastive queries from the transformer decoder.
+        transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
+            Mask predictions from the last layer in the transformer decoder.
+        transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
+            Class predictions from the last layer in the transformer decoder.
+        transformer_decoder_auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*):
+            List of class and mask predictions from each layer of the transformer decoder.
+        text_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
+            Text queries derived from the input text list, used for calculating the contrastive loss during training.
+        task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
+            1D task token to condition the queries.
+        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Self- and cross-attention weights from the transformer decoder.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    class_queries_logits: Optional[torch.FloatTensor] = None
+    masks_queries_logits: Optional[torch.FloatTensor] = None
+    auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    transformer_decoder_object_queries: Optional[torch.FloatTensor] = None
+    transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
+    transformer_decoder_mask_predictions: Optional[torch.FloatTensor] = None
+    transformer_decoder_class_predictions: Optional[torch.FloatTensor] = None
+    transformer_decoder_auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None
+    text_queries: Optional[torch.FloatTensor] = None
+    task_token: Optional[torch.FloatTensor] = None
+    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+
+
+# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrFrozenBatchNorm2d with DeformableDetr->OneFormerPixelDecoder
+class OneFormerPixelDecoderFrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any model other than
+    torchvision.models.resnet[18,34,50,101] produces nans.
+    """
+
+    def __init__(self, n):
+        super().__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x):
+        weight = self.weight.reshape(1, -1, 1, 1)
+        bias = self.bias.reshape(1, -1, 1, 1)
+        running_var = self.running_var.reshape(1, -1, 1, 1)
+        running_mean = self.running_mean.reshape(1, -1, 1, 1)
+        epsilon = 1e-5
+        scale = weight * (running_var + epsilon).rsqrt()
+        bias = bias - running_mean * scale
+        return x * scale + bias
+
+
+# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OneFormerPixelDecoderEncoder
+class OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
+    """
+    Multiscale deformable attention as proposed in Deformable DETR.
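+
+    Example (an illustrative sketch of how 2-d reference points and predicted offsets combine into normalized
+    sampling locations, mirroring the computation in `forward` below):
+
+    ```python
+    >>> import torch
+
+    >>> spatial_shapes = torch.tensor([[32, 32], [16, 16]])  # (n_levels, 2) as (height, width)
+    >>> offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+    >>> reference_points = torch.rand(1, 4, 2, 2)  # (batch_size, num_queries, n_levels, 2) in [0, 1]
+    >>> sampling_offsets = torch.randn(1, 4, 8, 2, 4, 2)  # (batch_size, num_queries, n_heads, n_levels, n_points, 2)
+    >>> sampling_locations = (
+    ...     reference_points[:, :, None, :, None, :]
+    ...     + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+    ... )
+    >>> sampling_locations.shape
+    torch.Size([1, 4, 8, 2, 4, 2])
+    ```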
+ """ + + def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" + ) + dim_per_head = embed_dim // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 128 + + self.d_model = embed_dim + self.n_levels = n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = nn.functional.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + # PyTorch implementation + output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, 
attention_weights) + output = self.output_proj(output) + + return output, attention_weights + + +class OneFormerPixelDecoderEncoderLayer(nn.Module): + def __init__(self, config: OneFormerConfig): + super().__init__() + self.embed_dim = config.conv_dim + self.self_attn = OneFormerPixelDecoderEncoderMultiscaleDeformableAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + n_levels=3, + n_points=4, + ) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = nn.functional.relu + self.activation_dropout = config.dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim) + self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.is_training = config.is_training + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Input to the layer. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Attention mask. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings, to be added to `hidden_states`. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes of the backbone feature maps. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. 
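+        # in this encoder-only pixel decoder, the layer attends over its own flattened multi-scale feature maps,
+        # so the same tensor is passed as both `hidden_states` (queries) and `encoder_hidden_states` (keys/values)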
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            position_embeddings=position_embeddings,
+            reference_points=reference_points,
+            spatial_shapes=spatial_shapes,
+            level_start_index=level_start_index,
+            output_attentions=output_attentions,
+        )
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.is_training)
+
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training)
+
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        if self.is_training:
+            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->OneFormerPixelDecoderEncoderOnly
+class OneFormerPixelDecoderEncoderOnly(nn.Module):
+    """
+    Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
+    [`OneFormerPixelDecoderEncoderLayer`].
+
+    The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers.
+
+    Args:
+        config: OneFormerConfig
+    """
+
+    def __init__(self, config: OneFormerConfig):
+        super().__init__()
+
+        self.config = config
+        self.dropout = config.dropout
+        self.layers = nn.ModuleList([OneFormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+    @staticmethod
+    def get_reference_points(spatial_shapes, valid_ratios, device):
+        """
+        Get reference points for each feature map. Used in the encoder.
+
+        Args:
+            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+                Spatial shapes of each feature map.
+            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+                Valid ratios of each feature map.
+            device (`torch.device`):
+                Device on which to create the tensors.
+ Returns: + `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` + """ + reference_points_list = [] + for lvl, (height, width) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), + torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
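+
+        Example (an illustrative sketch of the reference points produced by the static `get_reference_points` helper
+        above, for two toy feature maps without padding):
+
+        ```python
+        >>> import torch
+
+        >>> spatial_shapes = torch.tensor([[2, 3], [1, 2]])  # two feature maps: 2x3 and 1x2
+        >>> valid_ratios = torch.ones(1, 2, 2)  # fully valid feature maps (no padding)
+        >>> points = OneFormerPixelDecoderEncoderOnly.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
+        >>> points.shape  # (batch_size, sum(height * width), num_feature_levels, 2)
+        torch.Size([1, 8, 2, 2])
+        ```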
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Modified from from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoder with Mask2->One +class OneFormerPixelDecoder(nn.Module): + def __init__(self, config: OneFormerConfig, feature_channels): + super().__init__() + + self.config = config + + # positional encoding + self.position_embedding = OneFormerSinePositionEmbedding(num_pos_feats=config.conv_dim // 2, normalize=True) + self.num_feature_levels = 3 + transformer_in_channels = feature_channels[-self.num_feature_levels :] + self.transformer_feature_strides = config.strides[-self.num_feature_levels :] + self.feature_channels = feature_channels + self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, config.conv_dim)) + + # Create input projection layers + if self.num_feature_levels > 1: + input_projections_list = [] + for in_channels in transformer_in_channels[::-1]: + input_projections_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.conv_dim, kernel_size=1), + nn.GroupNorm(32, config.conv_dim), + ) + ) + self.input_projections = nn.ModuleList(input_projections_list) + else: + self.input_projections = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(transformer_in_channels[-1], config.conv_dim, kernel_size=1), + nn.GroupNorm(32, config.conv_dim), + ) + ] + ) + + self.encoder = OneFormerPixelDecoderEncoderOnly(config) + + self.mask_projection = nn.Conv2d( + config.conv_dim, + config.mask_dim, + kernel_size=1, + stride=1, + padding=0, + ) + + self.common_stride = config.common_stride + + # extra fpn levels + stride = min(self.transformer_feature_strides) + self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) + + lateral_convs = [] + output_convs = [] + + for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]): + lateral_conv = nn.Sequential( + nn.Conv2d( + in_channels, + config.conv_dim, + kernel_size=1, + bias=False, + ), + nn.GroupNorm(32, config.conv_dim), + ) + output_conv = nn.Sequential( + nn.Conv2d( + config.conv_dim, + config.conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.GroupNorm(32, config.conv_dim), + nn.ReLU(), + ) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), 
output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + def get_valid_ratio(self, mask, dtype=torch.float32): + """Get the valid ratio of all feature maps.""" + + _, height, width = mask.shape + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) + valid_ratio_heigth = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) + return valid_ratio + + def forward( + self, + features, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + sources = [] + position_embeddings_list = [] + for level, source in enumerate(features[::-1][: self.num_feature_levels]): + sources.append(self.input_projections[level](source)) + position_embeddings_list.append(self.position_embedding(source)) + + masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources] + + # Prepare encoder inputs (by flattening) + source_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + source_flatten.append(source) + mask_flatten.append(mask) + source_flatten = torch.cat(source_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) + + # Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder + # Also provide spatial_shapes, level_start_index and valid_ratios + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=source_flatten, + attention_mask=mask_flatten, + position_embeddings=lvl_pos_embed_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + y = encoder_outputs.last_hidden_state + bs = y.shape[0] + + split_size_or_sections = [None] * self.num_feature_levels + for i in range(self.num_feature_levels): + if i < self.num_feature_levels - 1: + split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] + else: + 
split_size_or_sections[i] = y.shape[1] - level_start_index[i] + y = torch.split(y, split_size_or_sections, dim=1) + + out = [] + multi_scale_features = [] + num_cur_levels = 0 + for i, z in enumerate(y): + out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) + + # append `out` with extra FPN levels + # Reverse feature maps into top-down order (from low to high resolution) + for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]): + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + cur_fpn = lateral_conv(feats) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + nn.functional.interpolate( + out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False + ) + y = output_conv(y) + out.append(y) + + for o in out: + if num_cur_levels < self.num_feature_levels: + multi_scale_features.append(o) + num_cur_levels += 1 + + return OneFormerPixelDecoderOutput( + mask_features=self.mask_projection(out[-1]), + multi_scale_features=multi_scale_features, + attentions=encoder_outputs.attentions, + ) + + +# Modified from from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelLevelModule with Mask2->One +class OneFormerPixelLevelModule(nn.Module): + def __init__(self, config: OneFormerConfig): + """ + Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image + Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel + decoder, generating multi-scale feature maps and pixel embeddings. + + Args: + config ([`OneFormerConfig`]): + The configuration used to instantiate this model. + """ + super().__init__() + self.encoder = load_backbone(config) + self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels) + + def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput: + features: List[Tensor] = self.encoder(pixel_values).feature_maps + decoder_output: OneFormerPixelDecoderOutput = self.decoder(features, output_hidden_states=output_hidden_states) + return OneFormerPixelLevelModuleOutput( + encoder_features=tuple(features), + decoder_features=decoder_output.multi_scale_features, + decoder_last_feature=decoder_output.mask_features, + ) + + +# Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->OneFormer +class OneFormerAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and + keys (as explained in the DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + key_value_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None + position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None + key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None + key_value_position_embeddings = ( + key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size, target_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # add key-value position embeddings to the key value states + if key_value_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(target_len, batch_size * self.num_heads, source_len)}, but is" + f" 
{attention_mask.size()}" + ) + attn_weights += attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output).permute(1, 0, 2) + + return attn_output, attn_weights_reshaped + + +class OneFormerTransformerDecoderSelfAttentionLayer(nn.Module): + def __init__( + self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False, layer_norm_eps=1e-05 + ): + super().__init__() + self.self_attn = OneFormerAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, is_decoder=True) + + self.norm = nn.LayerNorm(embed_dim, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + output, + output_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2, attention_weights = self.self_attn( + hidden_states=output, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True + ) + output = output + self.dropout(output2) + output = self.norm(output) + + return output, attention_weights + + def forward_pre( + self, + output, + output_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2 = self.norm(output) + output2, attention_weights = self.self_attn( + hidden_states=output2, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True + ) + output = output + self.dropout(output2) + + return output, attention_weights + + def forward( + self, + output, + output_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(output, output_mask, output_key_padding_mask, query_pos) + return self.forward_post(output, output_mask, output_key_padding_mask, query_pos) + + +class OneFormerTransformerDecoderCrossAttentionLayer(nn.Module): + def __init__( + self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False, layer_norm_eps=1e-05 + ): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout) + + self.norm = nn.LayerNorm(embed_dim, eps=layer_norm_eps) + self.dropout = 
nn.Dropout(dropout) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + output, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2, attention_weights = self.multihead_attn( + query=self.with_pos_embed(output, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output = output + self.dropout(output2) + output = self.norm(output) + + return output, attention_weights + + def forward_pre( + self, + output, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2 = self.norm(output) + output2, attention_weights = self.multihead_attn( + query=self.with_pos_embed(output2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output = output + self.dropout(output2) + + return output, attention_weights + + def forward( + self, + output, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos) + + +class OneFormerTransformerDecoderFFNLayer(nn.Module): + def __init__( + self, + d_model, + dim_feedforward=2048, + dropout=0.0, + activation="relu", + normalize_before=False, + layer_norm_eps=1e-05, + ): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, output): + output2 = self.linear2(self.dropout(self.activation(self.linear1(output)))) + output = output + self.dropout(output2) + output = self.norm(output) + return output + + def forward_pre(self, output): + output2 = self.norm(output) + output2 = self.linear2(self.dropout(self.activation(self.linear1(output2)))) + output = output + self.dropout(output2) + return output + + def forward(self, output): + if self.normalize_before: + return self.forward_pre(output) + return self.forward_post(output) + + +class OneFormerMLPPredictionHead(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3): + """ + A classic Multi Layer Perceptron (MLP). + + Args: + input_dim (`int`): + The input dimensions. + hidden_dim (`int`): + The hidden dimensions. + output_dim (`int`): + The output dimensions. + num_layers (int, *optional*, defaults to 3): + The number of layers. 
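+
+        Example (an illustrative sketch, assuming the `PredictionBlock` helper defined earlier in this file):
+
+        ```python
+        >>> import torch
+
+        >>> head = OneFormerMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=128, num_layers=3)
+        >>> head(torch.randn(2, 10, 256)).shape
+        torch.Size([2, 10, 128])
+        ```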
+ """ + super().__init__() + in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) + out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] + + layers = [] + for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): + layers.append( + PredictionBlock(in_dim, out_dim, activation=nn.ReLU() if i < num_layers - 1 else nn.Identity()) + ) + + self.layers = nn.Sequential(*layers) + + def forward(self, input: Tensor) -> Tensor: + return self.layers(input) + + +# refactored from original implementation +class OneFormerTransformerDecoderLayer(nn.Module): + def __init__(self, config: OneFormerConfig): + super().__init__() + self.embed_dim = config.hidden_dim + self.num_feature_levels = 3 + + self.cross_attn = OneFormerTransformerDecoderCrossAttentionLayer( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=0.0, + normalize_before=config.pre_norm, + layer_norm_eps=config.layer_norm_eps, + ) + + self.self_attn = OneFormerTransformerDecoderSelfAttentionLayer( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=0.0, + normalize_before=config.pre_norm, + layer_norm_eps=config.layer_norm_eps, + ) + + self.ffn = OneFormerTransformerDecoderFFNLayer( + d_model=self.embed_dim, + dim_feedforward=config.dim_feedforward, + dropout=0.0, + normalize_before=config.pre_norm, + layer_norm_eps=config.layer_norm_eps, + ) + + def forward( + self, + index: int, + output: torch.Tensor, + multi_stage_features: List[torch.Tensor], + multi_stage_positional_embeddings: List[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + query_embeddings: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + index (`int`): index of the layer in the Transformer decoder. + output (`torch.FloatTensor`): the object queries of shape `(N, batch, hidden_dim)` + multi_stage_features (`List[torch.Tensor]`): the multi-scale features from the pixel decoder. + multi_stage_positional_embeddings (`List[torch.Tensor]`): + positional embeddings for the multi_stage_features + attention_mask (`torch.FloatTensor`): attention mask for the masked cross attention layer + query_embeddings (`torch.FloatTensor`, *optional*): + position embeddings that are added to the queries and keys in the self-attention layer. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + + level_index = index % self.num_feature_levels + attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False + + # Masked Cross Attention + output, cross_attn_weights = self.cross_attn( + output, + multi_stage_features[level_index], + memory_mask=attention_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=multi_stage_positional_embeddings[level_index], + query_pos=query_embeddings, + ) + + # Self Attention + output, self_attn_weights = self.self_attn( + output, + output_mask=None, + output_key_padding_mask=None, + query_pos=query_embeddings, + ) + + # Fully Connected + output = self.ffn(output) + + outputs = (output,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class OneFormerTransformerDecoderQueryTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + output_mask=output_mask, + memory_mask=memory_mask, + output_key_padding_mask=output_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class OneFormerTransformerDecoderQueryTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + layer_norm_eps=1e-05, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = ACT2FN[activation] + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(output, query_pos) + output2 = self.self_attn(q, k, value=output, attn_mask=output_mask, 
key_padding_mask=output_key_padding_mask) + output2 = output2[0] + output = output + self.dropout1(output2) + output = self.norm1(output) + output2 = self.multihead_attn( + query=self.with_pos_embed(output, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output2 = output2[0] + output = output + self.dropout2(output2) + output = self.norm2(output) + output2 = self.linear2(self.dropout(self.activation(self.linear1(output)))) + output = output + self.dropout3(output2) + output = self.norm3(output) + return output + + def forward_pre( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output2 = self.norm1(output) + q = k = self.with_pos_embed(output2, query_pos) + output2 = self.self_attn(q, k, value=output2, attn_mask=output_mask, key_padding_mask=output_key_padding_mask) + output2 = output2[0] + output = output + self.dropout1(output2) + output2 = self.norm2(output) + output2 = self.multihead_attn( + query=self.with_pos_embed(output2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + ) + output2 = output2[0] + output = output + self.dropout2(output2) + output2 = self.norm3(output) + output2 = self.linear2(self.dropout(self.activation(self.linear1(output2)))) + output = output + self.dropout3(output2) + return output + + def forward( + self, + output, + memory, + output_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + output_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + output, + memory, + output_mask, + memory_mask, + output_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + output, + memory, + output_mask, + memory_mask, + output_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +class OneFormerTransformerDecoderQueryTransformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + layer_norm_eps=1e-05, + ): + super().__init__() + + decoder_layer = OneFormerTransformerDecoderQueryTransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before, layer_norm_eps + ) + decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.decoder = OneFormerTransformerDecoderQueryTransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self.d_model = d_model + self.nhead = nhead + + def forward(self, src, mask, query_embed, pos_embed, task_token=None): + batch_size = src.shape[0] + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1) + if mask is not None: + mask = mask.flatten(1) + + if task_token is None: + queries = torch.zeros_like(query_embed) + else: + queries = task_token.repeat(query_embed.shape[0], 1, 1) + + queries = self.decoder(queries, src, 
memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) + return queries.transpose(1, 2) + + +class OneFormerTransformerDecoder(nn.Module): + """ + Transformer decoder + """ + + def __init__(self, in_channels: int, config: OneFormerConfig): + super().__init__() + self.config = config + + self.dropout = config.dropout + self.num_heads = config.num_attention_heads + self.is_training = config.is_training + self.use_task_norm = config.use_task_norm + self.use_auxiliary_loss = config.use_auxiliary_loss + + self.query_transformer = OneFormerTransformerDecoderQueryTransformer( + d_model=config.hidden_dim, + dropout=config.dropout, + nhead=config.num_attention_heads, + dim_feedforward=config.dim_feedforward, + num_decoder_layers=config.query_dec_layers, + normalize_before=config.pre_norm, + return_intermediate_dec=False, + layer_norm_eps=config.layer_norm_eps, + ) + + self.decoder_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps) + + self.num_feature_levels = 3 + + self.layers = nn.ModuleList( + [OneFormerTransformerDecoderLayer(config) for _ in range(config.decoder_layers - 1)] + ) + + self.query_input_projection = nn.Conv2d(in_channels, config.hidden_dim, kernel_size=1) + + self.class_embed = nn.Linear(config.hidden_dim, config.num_labels + 1) + self.mask_embed = OneFormerMLPPredictionHead( + config.hidden_dim, + config.hidden_dim, + config.mask_dim, + 3, + ) + + def forward( + self, + task_token=None, + multi_stage_features=None, + multi_stage_positional_embeddings=None, + mask_features=None, + query_features=None, + query_embeddings=None, + query_embedder=None, + size_list=None, + output_attentions=None, + ): + if self.use_task_norm: + task_token = self.decoder_norm(task_token) + + object_queries = self.query_transformer( + query_features, + None, + query_embedder.weight[:-1], + self.query_input_projection(mask_features), + task_token if self.use_task_norm else None, + ) + + object_queries = object_queries[0].permute(1, 0, 2) + + queries = torch.cat([object_queries, task_token], dim=0) + + output = queries.clone() + + intermediate_class_predictions = [] + intermediate_mask_predictions = [] + + # prediction heads on learnable query features + outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads( + output, mask_features, attention_mask_target_size=size_list[0] + ) + intermediate_class_predictions.append(outputs_class) + intermediate_mask_predictions.append(outputs_mask) + + attentions = () + + for index, layer in enumerate(self.layers): + layer_outputs = layer( + index=index, + output=output, + multi_stage_features=multi_stage_features, + multi_stage_positional_embeddings=multi_stage_positional_embeddings, + attention_mask=attention_mask, + query_embeddings=query_embeddings, + output_attentions=output_attentions, + ) + + output = layer_outputs[0] + attentions += (layer_outputs[1:],) + + outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads( + output, mask_features, attention_mask_target_size=size_list[(index + 1) % self.num_feature_levels] + ) + intermediate_class_predictions.append(outputs_class) + intermediate_mask_predictions.append(outputs_mask) + + if not len(intermediate_mask_predictions) == len(self.layers) + 1: + raise ValueError( + "Intermediate predictions in the transformer decoder must have the same number of elements as number" + " of layers" + ) + + object_queries = layer_outputs[0].permute(1, 0, 2) + + contrastive_logits = queries.permute(1, 0, 2) + + return OneFormerTransformerDecoderOutput( + 
object_queries=object_queries, + contrastive_logits=contrastive_logits, + prediction_masks=intermediate_mask_predictions[-1], + prediction_class=intermediate_class_predictions[-1], + auxiliary_predictions=self._get_aux_predictions( + intermediate_class_predictions, intermediate_mask_predictions + ) + if self.use_auxiliary_loss + else None, + attentions=attentions, + ) + + def forward_prediction_heads(self, output, mask_features, attention_mask_target_size): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) + outputs_class = self.class_embed(decoder_output) + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + attention_mask = nn.functional.interpolate( + outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False + ) + + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attention_mask = ( + attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5 + ).bool() + attention_mask = attention_mask.detach() + + return outputs_class, outputs_mask, attention_mask + + @torch.jit.unused + def _get_aux_predictions(self, outputs_class, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + aux_list = [ + {"class_queries_logits": a, "masks_queries_logits": b} + for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) + ] + return tuple(aux_list) + + +class OneFormerTransformerModule(nn.Module): + """ + The OneFormer's transformer module. + """ + + def __init__(self, in_features: int, config: OneFormerConfig): + super().__init__() + hidden_dim = config.hidden_dim + self.num_feature_levels = 3 + self.position_embedder = OneFormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True) + self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim) + self.input_projections = [] + + for _ in range(self.num_feature_levels): + if in_features != hidden_dim or config.enforce_input_proj: + self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1)) + else: + self.input_projections.append(nn.Sequential()) + + self.decoder = OneFormerTransformerDecoder(in_channels=in_features, config=config) + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + + def forward( + self, + multi_scale_features: List[Tensor], + mask_features: Tensor, + task_token: Tensor, + output_attentions: bool = False, + ) -> OneFormerTransformerDecoderOutput: + if not len(multi_scale_features) == self.num_feature_levels: + raise ValueError( + f"Number of elements in multi_scale_features ({len(multi_scale_features)}) and num_feature_levels" + f" ({self.num_feature_levels}) do not match!" 
+ ) + multi_stage_features = [] + multi_stage_positional_embeddings = [] + size_list = [] + + for i in range(self.num_feature_levels): + size_list.append(multi_scale_features[i].shape[-2:]) + multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) + multi_stage_features.append( + self.input_projections[i](multi_scale_features[i]).flatten(2) + + self.level_embed.weight[i][None, :, None] + ) + + # flatten NxCxHxW to HWxNxC + multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) + multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) + + _, batch_size, _ = multi_stage_features[0].shape + + # QxNxC + query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) + task_token = task_token.unsqueeze(0) + + query_features = self.position_embedder(mask_features, None) + + return self.decoder( + task_token=task_token, + multi_stage_features=multi_stage_features, + multi_stage_positional_embeddings=multi_stage_positional_embeddings, + mask_features=mask_features, + query_features=query_features, + query_embeddings=query_embeddings, + query_embedder=self.queries_embedder, + size_list=size_list, + output_attentions=output_attentions, + ) + + +# Copied from transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with Mask->One +class OneFormerSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = (~mask).to(x.dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock +class PredictionBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None: + super().__init__() + self.layers = [nn.Linear(in_dim, out_dim), activation] + # Maintain submodule indexing as if part of a Sequential block + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + 
hidden_state = layer(hidden_state) + return hidden_state + + +class OneFormerTextMapperAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, k, v): + batch_size, q_sequence_length, num_channels = q.shape + if not k.shape == v.shape: + raise ValueError(f"keys ({list(k.shape)}) and values ({list(v.shape)}) have different shapes!") + batch_size, k_sequence_length, num_channels = k.shape + q = self.q_proj(q).reshape(batch_size, q_sequence_length, self.num_heads, num_channels // self.num_heads) + k = self.k_proj(k).reshape(batch_size, k_sequence_length, self.num_heads, num_channels // self.num_heads) + v = self.v_proj(v).reshape(batch_size, k_sequence_length, self.num_heads, num_channels // self.num_heads) + + attn = torch.einsum("bnkc,bmkc->bknm", q, k) * self.scale + + attn = attn.softmax(dim=-1) + + output = torch.einsum("bknm,bmkc->bnkc", attn, v).reshape(batch_size, q_sequence_length, num_channels) + + output = self.proj(output) + output = self.proj_drop(output) + return output + + +class OneFormerTextTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dropout=0.1, + layer_norm_eps=1e-05, + ): + super().__init__() + self.self_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout) + self.cross_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + + self.mlp = nn.Sequential( + nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 4, d_model) + ) + + def forward(self, hidden_state, mem): + q = k = v = self.norm1(hidden_state) + hidden_state = hidden_state + self.self_attn(q, k, v) + q = self.norm2(hidden_state) + hidden_state = hidden_state + self.cross_attn(q, mem, mem) + hidden_state = hidden_state + self.dropout(self.mlp(self.norm3(hidden_state))) + return hidden_state + + +class OneFormerTextContextDecoder(nn.Module): + def __init__( + self, + transformer_width=256, + transformer_heads=4, + transformer_layers=6, + visual_dim=1024, + dropout=0.1, + layer_norm_eps=1e-05, + **kwargs, + ): + super().__init__() + + self.memory_proj = nn.Sequential( + nn.LayerNorm(visual_dim, eps=layer_norm_eps), + nn.Linear(visual_dim, transformer_width), + nn.LayerNorm(transformer_width, eps=layer_norm_eps), + ) + + self.text_proj = nn.Sequential( + nn.LayerNorm(visual_dim, eps=layer_norm_eps), + nn.Linear(visual_dim, transformer_width), + ) + + self.decoder = nn.ModuleList( + [ + OneFormerTextTransformerDecoderLayer(transformer_width, transformer_heads, dropout, layer_norm_eps) + for _ in range(transformer_layers) + ] + ) + + self.out_proj = nn.Sequential( + nn.LayerNorm(transformer_width, eps=layer_norm_eps), nn.Linear(transformer_width, visual_dim) + ) + + def forward(self, text, visual): + visual = self.memory_proj(visual) 
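+        # project the text embeddings to the same transformer width, then let each decoder
+        # layer cross-attend from the text stream to the projected visual memory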
+ hidden_state = self.text_proj(text) + + for layer in self.decoder: + hidden_state = layer(hidden_state, visual) + + return self.out_proj(hidden_state) + + +class OneFormerTextMLP(nn.Module): + def __init__( + self, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, + output_size: Optional[int] = None, + ): + super().__init__() + self.activation_fn = ACT2FN["quick_gelu"] + hidden_size = hidden_size + intermediate_size = intermediate_size + output_size = output_size + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, output_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class OneFormerTextTransformerLayer(nn.Module): + def __init__(self, width: int, heads: int, attn_mask: torch.Tensor, layer_norm_eps=1e-05): + super().__init__() + self.self_attn = nn.MultiheadAttention(width, heads) + self.layer_norm1 = nn.LayerNorm(width, eps=layer_norm_eps) + self.mlp = OneFormerTextMLP(width, width * 4, width) + self.layer_norm2 = nn.LayerNorm(width, eps=layer_norm_eps) + self.attn_mask = attn_mask + + def forward( + self, + hidden_states: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states, + hidden_states, + hidden_states, + need_weights=False, + key_padding_mask=key_padding_mask, + )[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class OneFormerTextTransformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None, + use_checkpoint=False, + layer_norm_eps=1e-05, + ): + super().__init__() + self.width = width + self.num_layers = layers + self.layers = nn.Sequential( + *[OneFormerTextTransformerLayer(width, heads, attn_mask, layer_norm_eps) for _ in range(layers)] + ) + self.use_checkpoint = use_checkpoint + + def forward(self, hidden_states: torch.Tensor): + for layer in self.layers: + if self.use_checkpoint: + hidden_states = self._gradient_checkpointing_func(layer, hidden_states) + else: + hidden_states = layer(hidden_states) + return hidden_states + + +class OneFormerTextEncoder(nn.Module): + def __init__( + self, + context_length: int, + width: int, + layers: int, + vocab_size, + use_checkpoint=False, + layer_norm_eps=1e-05, + ): + super().__init__() + heads = width // 64 + self.context_length = context_length + self.width = width + self.transformer = OneFormerTextTransformer( + width=width, + layers=layers, + heads=heads, + attn_mask=self.build_attention_mask(), + use_checkpoint=use_checkpoint, + layer_norm_eps=layer_norm_eps, + ) + + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.ln_final = nn.LayerNorm(width, eps=layer_norm_eps) + self.token_embedding = nn.Embedding(vocab_size, width) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero 
out the lower triangle and the diagonal
+        return mask
+
+    def forward(self, text):
+        hidden_state = self.token_embedding(text)
+        hidden_state = hidden_state + self.positional_embedding
+        hidden_state = hidden_state.permute(1, 0, 2)
+        hidden_state = self.transformer(hidden_state)
+        hidden_state = hidden_state.permute(1, 0, 2)
+        hidden_state = self.ln_final(hidden_state)
+        hidden_state = hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]
+
+        return hidden_state
+
+
+class OneFormerTextMapper(nn.Module):
+    def __init__(self, config: OneFormerConfig):
+        super().__init__()
+        self.text_encoder = OneFormerTextEncoder(
+            context_length=config.text_encoder_context_length,
+            width=config.text_encoder_width,
+            layers=config.text_encoder_num_layers,
+            vocab_size=config.text_encoder_vocab_size,
+            layer_norm_eps=config.layer_norm_eps,
+        )
+
+        self.text_projector = OneFormerMLPPredictionHead(
+            config.text_encoder_width,
+            config.hidden_dim,
+            config.hidden_dim,
+            config.text_encoder_proj_layers,
+        )
+        if config.text_encoder_n_ctx > 0:
+            self.prompt_ctx = nn.Embedding(
+                config.text_encoder_n_ctx,
+                config.text_encoder_width,
+            )
+        else:
+            self.prompt_ctx = None
+
+    def forward(
+        self,
+        inputs: Tensor,
+    ) -> Tensor:
+        text_queries = self.encode_text(inputs)
+
+        return text_queries
+
+    def encode_text(self, text):
+        # check for None before touching `text.ndim`, otherwise the attribute access itself fails
+        if text is None:
+            raise ValueError("text must not be NoneType")
+        if text.ndim not in [2, 3]:
+            raise ValueError("Number of dimensions in text must be 2 or 3")
+        squeeze_dim = False
+        num_text = 1
+        if text.ndim == 3:
+            batch_size, num_text, seq_length = text.shape
+            text = text.reshape(batch_size * num_text, seq_length)
+            squeeze_dim = True
+
+        # [batch_size, num_channels]
+        encoded_text = self.text_encoder(text)
+
+        text_queries = self.text_projector(encoded_text)
+
+        if squeeze_dim:
+            _, hidden_dim = text_queries.shape
+            text_queries = text_queries.reshape(batch_size, num_text, hidden_dim)
+            if self.prompt_ctx is not None:
+                text_queries_ctx = self.prompt_ctx.weight.unsqueeze(0).repeat(text_queries.shape[0], 1, 1)
+                text_queries = torch.cat([text_queries, text_queries_ctx], dim=1)
+
+        return text_queries
+
+
+class OneFormerTaskModel(nn.Module):
+    def __init__(self, config: OneFormerConfig):
+        super().__init__()
+        self.task_mlp = OneFormerMLPPredictionHead(
+            config.task_seq_len,
+            config.hidden_dim,
+            config.hidden_dim,
+            2,
+        )
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        task_tokens = self.task_mlp(inputs)
+        return task_tokens
+
+
+ONEFORMER_START_DOCSTRING = r"""
+    This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use it as a
+    regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
+
+    Parameters:
+        config ([`OneFormerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ONEFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`OneFormerProcessor`]. See
+            [`OneFormerProcessor.__call__`] for details.
+        task_inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Task inputs. Task inputs can be obtained using [`AutoImageProcessor`].
See [`OneFormerProcessor.__call__`] + for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of Detr's decoder attention layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~OneFormerModelOutput`] instead of a plain tuple. +""" + + +class OneFormerPreTrainedModel(PreTrainedModel): + config_class = OneFormerConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module: nn.Module): + xavier_std = self.config.init_xavier_std + std = self.config.init_std + if isinstance(module, OneFormerTransformerModule): + if module.input_projections is not None: + for input_projection in module.input_projections: + if not isinstance(input_projection, nn.Sequential): + nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) + nn.init.constant_(input_projection.bias, 0) + elif isinstance(module, OneFormerTransformerDecoder): + nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) + nn.init.constant_(module.query_input_projection.bias, 0) + module.query_input_projection._is_hf_initialized = True + elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + elif isinstance(module, OneFormerPixelDecoderEncoderOnly): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + elif isinstance(module, OneFormerPixelDecoder): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + nn.init.normal_(module.level_embed, std=0) + elif isinstance(module, OneFormerTransformerDecoderSelfAttentionLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerTransformerDecoderCrossAttentionLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerTransformerDecoderFFNLayer): + for p in module.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p, gain=xavier_std) + elif isinstance(module, OneFormerTransformerDecoderQueryTransformer): + for p in 
module.parameters():
+                if p.dim() > 1:
+                    nn.init.xavier_uniform_(p, gain=xavier_std)
+        elif isinstance(module, OneFormerPixelLevelModule):
+            for submodule in module.modules():
+                if isinstance(submodule, (nn.Conv2d, nn.Linear)):
+                    submodule.weight.data.normal_(mean=0.0, std=std)
+                    if submodule.bias is not None:
+                        submodule.bias.data.zero_()
+        elif isinstance(module, OneFormerTextContextDecoder):
+            for submodule in module.modules():
+                if isinstance(submodule, nn.Linear):
+                    nn.init.trunc_normal_(submodule.weight, std=0.02)
+                    if submodule.bias is not None:
+                        nn.init.constant_(submodule.bias, 0)
+                elif isinstance(submodule, nn.LayerNorm):
+                    nn.init.constant_(submodule.bias, 0)
+                    nn.init.constant_(submodule.weight, 1.0)
+        elif isinstance(module, OneFormerTextTransformer):
+            proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5)
+            attn_std = module.width**-0.5
+            fc_std = (2 * module.width) ** -0.5
+            for layer in module.layers:
+                nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std)
+                nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)
+                nn.init.normal_(layer.mlp.fc1.weight, std=fc_std)
+                nn.init.normal_(layer.mlp.fc2.weight, std=proj_std)
+        elif isinstance(module, OneFormerTextEncoder):
+            nn.init.normal_(module.token_embedding.weight, std=0.02)
+            nn.init.normal_(module.positional_embedding, std=0.01)
+            if hasattr(module, "reference_points"):
+                nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
+                nn.init.constant_(module.reference_points.bias.data, 0.0)
+        elif isinstance(module, OneFormerTaskModel):
+            # initialize the linear layers of the MLP prediction head inside the task model
+            for submodule in module.modules():
+                if isinstance(submodule, OneFormerMLPPredictionHead):
+                    for layer in submodule.modules():
+                        if isinstance(layer, nn.Linear):
+                            nn.init.xavier_uniform_(layer.weight, gain=xavier_std)
+                            nn.init.constant_(layer.bias, 0)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.MultiheadAttention):
+            module.in_proj_weight.data.normal_(mean=0.0, std=std)
+            module.in_proj_bias.data.zero_()
+        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+@add_start_docstrings(
+    "The bare OneFormer Model outputting raw hidden-states without any specific head on top.",
+    ONEFORMER_START_DOCSTRING,
+)
+class OneFormerModel(OneFormerPreTrainedModel):
+    main_input_name = ["pixel_values", "task_inputs"]
+
+    def __init__(self, config: OneFormerConfig):
+        super().__init__(config)
+        self.pixel_level_module = OneFormerPixelLevelModule(config)
+        self.transformer_module = OneFormerTransformerModule(in_features=config.conv_dim, config=config)
+        self.task_encoder = OneFormerTaskModel(config)
+        self.is_training = config.is_training
+
+        if self.is_training:
+            self.text_mapper = OneFormerTextMapper(config)
+        else:
+            self.text_mapper = None
+
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=OneFormerModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: Tensor,
+        task_inputs: Tensor,
+        text_inputs: Optional[Tensor] = None,
+        pixel_mask: Optional[Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> OneFormerModelOutput:
+        r"""
+        Returns:
+            `OneFormerModelOutput`
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import OneFormerProcessor, OneFormerModel
+
+        >>> # download a test image
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> # load processor for preprocessing the inputs
+        >>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
+        >>> model = OneFormerModel.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
+        >>> inputs = processor(image, ["semantic"], return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> mask_predictions = outputs.transformer_decoder_mask_predictions
+        >>> class_predictions = outputs.transformer_decoder_class_predictions
+
+        >>> f"👉 Mask Predictions Shape: {list(mask_predictions.shape)}, Class Predictions Shape: {list(class_predictions.shape)}"
+        '👉 Mask Predictions Shape: [1, 150, 128, 171], Class Predictions Shape: [1, 150, 151]'
+        ```"""
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, _, height, width = pixel_values.shape
+
+        if pixel_mask is None:
+            pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
+
+        pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states)
+
+        multi_scale_features = pixel_level_module_output.decoder_features
+        mask_features = pixel_level_module_output.decoder_last_feature
+
+        task_token = self.task_encoder(task_inputs.to(self.dtype))
+
+        if self.is_training:
+            text_queries = self.text_mapper(text_inputs)
+        else:
+            text_queries = None
+
+        transformer_module_output = self.transformer_module(
+            multi_scale_features=multi_scale_features,
+            mask_features=mask_features,
+            task_token=task_token,
+            output_attentions=output_attentions,
+        )
+
+        queries = transformer_module_output.object_queries
+
+        encoder_hidden_states = None
+        pixel_decoder_hidden_states = None
+        transformer_decoder_hidden_states = None
+
+        if output_hidden_states:
+            encoder_hidden_states = pixel_level_module_output.encoder_features
+            pixel_decoder_hidden_states = (pixel_level_module_output.decoder_last_feature,)
+            for f in pixel_level_module_output.decoder_features:
+                pixel_decoder_hidden_states += (f,)
+            transformer_decoder_hidden_states = transformer_module_output.auxiliary_predictions
+
+        output = OneFormerModelOutput(
+            encoder_hidden_states=encoder_hidden_states,
+            pixel_decoder_hidden_states=pixel_decoder_hidden_states,
+            transformer_decoder_hidden_states=transformer_decoder_hidden_states,
+            transformer_decoder_object_queries=queries,
+            transformer_decoder_contrastive_queries=transformer_module_output.contrastive_logits,
+            transformer_decoder_mask_predictions=transformer_module_output.prediction_masks,
+            transformer_decoder_class_predictions=transformer_module_output.prediction_class,
+            transformer_decoder_auxiliary_predictions=transformer_module_output.auxiliary_predictions,
+            text_queries=text_queries,
+            task_token=task_token,
+            attentions=transformer_module_output.attentions,
+        )
+
+        if not return_dict:
+            output = tuple(v for v in output.values())
+
+        return output
+
+
+@add_start_docstrings(
+    "OneFormer Model for instance, semantic and panoptic image segmentation.",
+    ONEFORMER_START_DOCSTRING,
+)
+class OneFormerForUniversalSegmentation(OneFormerPreTrainedModel):
+    main_input_name = ["pixel_values", "task_inputs"]
+
+    def __init__(self, config: OneFormerConfig):
+        super().__init__(config)
+        self.model = OneFormerModel(config)
+
+        self.matcher = OneFormerHungarianMatcher(
+            cost_class=config.class_weight,
+            cost_dice=config.dice_weight,
+            cost_mask=config.mask_weight,
+            num_points=config.train_num_points,
+        )
+
+        self.weight_dict: Dict[str, float] = {
+            "loss_cross_entropy": config.class_weight,
+            "loss_mask": config.mask_weight,
+            "loss_dice": config.dice_weight,
+            "loss_contrastive": config.contrastive_weight,
+        }
+
+        self.criterion = OneFormerLoss(
+            num_classes=config.num_labels,
+            matcher=self.matcher,
+            weight_dict=self.weight_dict,
+            eos_coef=config.no_object_weight,
+            num_points=config.train_num_points,
+            oversample_ratio=config.oversample_ratio,
+            importance_sample_ratio=config.importance_sample_ratio,
+            contrastive_temperature=config.contrastive_temperature,
+        )
+
+        self.post_init()
+
+    def get_loss_dict(
+        self,
+        masks_queries_logits: Tensor,
+        class_queries_logits: Tensor,
+        contrastive_queries_logits: Tensor,
+        mask_labels: Tensor,
+        class_labels: Tensor,
+        text_queries: Tensor,
+        auxiliary_predictions: Dict[str, Tensor],
+        calculate_contrastive_loss: bool,
+    ) -> Dict[str, Tensor]:
+        loss_dict: Dict[str, Tensor] = self.criterion(
+            masks_queries_logits=masks_queries_logits,
+            class_queries_logits=class_queries_logits,
+            contrastive_queries_logits=contrastive_queries_logits,
+            mask_labels=mask_labels,
+            class_labels=class_labels,
+            text_queries=text_queries,
+            auxiliary_predictions=auxiliary_predictions,
+            calculate_contrastive_loss=calculate_contrastive_loss,
+        )
+
+        # weight each loss by its corresponding entry in `self.weight_dict`, including auxiliary losses;
+        # the in-place multiply updates the tensors stored in `loss_dict`
+        for key, weight in self.weight_dict.items():
+            for loss_key, loss in loss_dict.items():
+                if key in loss_key:
+                    loss *= weight
+
+        return loss_dict
+
+    def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
+        return sum(loss_dict.values())
+
+    @add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=OneFormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: Tensor,
+        task_inputs: Tensor,
+        text_inputs: Optional[Tensor] = None,
+        mask_labels: Optional[List[Tensor]] = None,
+        class_labels: Optional[List[Tensor]] = None,
+        pixel_mask: Optional[Tensor] = None,
+        output_auxiliary_logits: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> OneFormerForUniversalSegmentationOutput:
+        r"""
+        text_inputs (`List[torch.Tensor]`, *optional*):
+            Tensor of shape `(num_queries, sequence_length)` to be fed to a model
+        mask_labels (`List[torch.Tensor]`, *optional*):
+            List of mask labels of shape `(num_labels, height, width)` to be fed to a model
+        class_labels (`List[torch.LongTensor]`, *optional*):
+            list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
+            labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
+
+        Returns:
+            `OneFormerForUniversalSegmentationOutput`
+        Example:
+
+        Universal segmentation example:
+
+        ```python
+        >>> from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
+        >>> from PIL import Image
+        >>> import requests
+        >>> import torch
+
+        >>> # load OneFormer fine-tuned on ADE20k for universal segmentation
+        >>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
+        >>> model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
+
+        >>> url = (
+        ...     "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
+        ... )
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> # Semantic Segmentation
+        >>> inputs = processor(image, ["semantic"], return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)`
+        >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
+        >>> class_queries_logits = outputs.class_queries_logits
+        >>> masks_queries_logits = outputs.masks_queries_logits
+
+        >>> # you can pass them to processor for semantic postprocessing
+        >>> predicted_semantic_map = processor.post_process_semantic_segmentation(
+        ...     outputs, target_sizes=[(image.height, image.width)]
+        ... )[0]
+        >>> f"👉 Semantic Predictions Shape: {list(predicted_semantic_map.shape)}"
+        '👉 Semantic Predictions Shape: [512, 683]'
+
+        >>> # Instance Segmentation
+        >>> inputs = processor(image, ["instance"], return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)`
+        >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
+        >>> class_queries_logits = outputs.class_queries_logits
+        >>> masks_queries_logits = outputs.masks_queries_logits
+
+        >>> # you can pass them to processor for instance postprocessing
+        >>> predicted_instance_map = processor.post_process_instance_segmentation(
+        ...     outputs, target_sizes=[(image.height, image.width)]
+        ... )[0]["segmentation"]
+        >>> f"👉 Instance Predictions Shape: {list(predicted_instance_map.shape)}"
+        '👉 Instance Predictions Shape: [512, 683]'
+
+        >>> # Panoptic Segmentation
+        >>> inputs = processor(image, ["panoptic"], return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)`
+        >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
+        >>> class_queries_logits = outputs.class_queries_logits
+        >>> masks_queries_logits = outputs.masks_queries_logits
+
+        >>> # you can pass them to processor for panoptic postprocessing
+        >>> predicted_panoptic_map = processor.post_process_panoptic_segmentation(
+        ...     outputs, target_sizes=[(image.height, image.width)]
+        ... )[0]["segmentation"]
+        >>> f"👉 Panoptic Predictions Shape: {list(predicted_panoptic_map.shape)}"
+        '👉 Panoptic Predictions Shape: [512, 683]'
+        ```
+        """
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            pixel_values=pixel_values,
+            task_inputs=task_inputs,
+            text_inputs=text_inputs,
+            pixel_mask=pixel_mask,
+            output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
+            output_attentions=output_attentions,
+            return_dict=True,
+        )
+
+        loss, loss_dict, auxiliary_predictions = None, None, None
+
+        class_queries_logits = outputs.transformer_decoder_class_predictions
+        masks_queries_logits = outputs.transformer_decoder_mask_predictions
+        contrastive_queries_logits = outputs.transformer_decoder_contrastive_queries
+        auxiliary_predictions = outputs.transformer_decoder_auxiliary_predictions
+        text_queries = outputs.text_queries
+
+        if mask_labels is not None and class_labels is not None:
+            loss_dict: Dict[str, Tensor] = self.get_loss_dict(
+                masks_queries_logits=masks_queries_logits,
+                class_queries_logits=class_queries_logits,
+                contrastive_queries_logits=contrastive_queries_logits,
+                mask_labels=mask_labels,
+                class_labels=class_labels,
+                text_queries=text_queries,
+                auxiliary_predictions=auxiliary_predictions,
+                calculate_contrastive_loss=self.config.contrastive_temperature is not None,
+            )
+            loss = self.get_loss(loss_dict)
+
+        output_auxiliary_logits = (
+            self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
+        )
+        if not output_auxiliary_logits:
+            auxiliary_predictions = None
+
+        output = OneFormerForUniversalSegmentationOutput(
+            class_queries_logits=class_queries_logits,
+            masks_queries_logits=masks_queries_logits,
+            auxiliary_predictions=auxiliary_predictions,
+            loss=loss,
+            **outputs,
+        )
+
+        if not return_dict:
+            output = tuple(v for v in output.values())
+            if loss is not None:
+                # prepend the loss as a one-element tuple; a bare `(loss)` is a plain tensor, not a tuple
+                output = (loss,) + output
+        return output
+
+
+__all__ = ["OneFormerForUniversalSegmentation", "OneFormerModel", "OneFormerPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/oneformer/processing_oneformer.py b/docs/transformers/src/transformers/models/oneformer/processing_oneformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3e02f50d8170833c54cd77f85669c7511227c03
--- /dev/null
+++ b/docs/transformers/src/transformers/models/oneformer/processing_oneformer.py
@@ -0,0 +1,207 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for OneFormer
+"""
+
+from typing import List
+
+from ...processing_utils import ProcessorMixin
+from ...utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+class OneFormerProcessor(ProcessorMixin):
+    r"""
+    Constructs a OneFormer processor which wraps [`OneFormerImageProcessor`] and
+    [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into a single processor that inherits both the image processor and
+    tokenizer functionalities.
+
+    Args:
+        image_processor ([`OneFormerImageProcessor`]):
+            The image processor is a required input.
+        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
+            The tokenizer is a required input.
+        max_seq_length (`int`, *optional*, defaults to 77):
+            Sequence length for the input text list.
+        task_seq_length (`int`, *optional*, defaults to 77):
+            Sequence length for the input task token.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "OneFormerImageProcessor"
+    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+    def __init__(
+        self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
+    ):
+        if image_processor is None:
+            raise ValueError("You need to specify an `image_processor`.")
+        if tokenizer is None:
+            raise ValueError("You need to specify a `tokenizer`.")
+
+        self.max_seq_length = max_seq_length
+        self.task_seq_length = task_seq_length
+
+        super().__init__(image_processor, tokenizer)
+
+    def _preprocess_text(self, text_list=None, max_length=77):
+        if text_list is None:
+            raise ValueError("text_list cannot be None.")
+
+        tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
+
+        attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
+
+        token_inputs = []
+        for attn_mask, input_id in zip(attention_masks, input_ids):
+            # multiplying by the attention mask zeroes out the ids of padding positions
+            token = torch.tensor(attn_mask) * torch.tensor(input_id)
+            token_inputs.append(token.unsqueeze(0))
+
+        token_inputs = torch.cat(token_inputs, dim=0)
+        return token_inputs
+
+    def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
+        """
+        Main method to prepare one or several task input(s) and image(s) for the model. This method forwards the
+        `task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not
+        `None` to encode. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the
+        docstring of the above two methods for more information.
+
+        Args:
+            task_inputs (`str`, `List[str]`):
+                The sequence or batch of task_inputs sequences to be encoded. Each sequence can be a string or a list
+                of strings of the template "the task is {task}".
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+            `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            segmentation_maps (`ImageInput`, *optional*):
+                The corresponding semantic segmentation maps with the pixel-wise annotations.
+
+             (`bool`, *optional*, defaults to `True`):
+                Whether or not to pad images up to the largest image in a batch and create a pixel mask.
+
+                If left to the default, will return a pixel mask that is:
+
+                - 1 for pixels that are real (i.e. **not masked**),
+                - 0 for pixels that are padding (i.e.
**masked**). + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **task_inputs** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if task_inputs is None: + raise ValueError("You have to specify the task_input. Found None.") + elif images is None: + raise ValueError("You have to specify the image. Found None.") + + if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): + raise ValueError("task_inputs must be semantic, instance, or panoptic.") + + encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs) + + if isinstance(task_inputs, str): + task_inputs = [task_inputs] + + if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs): + task_token_inputs = [] + for task in task_inputs: + task_input = f"the task is {task}" + task_token_inputs.append(task_input) + encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) + else: + raise TypeError("Task Inputs should be a string or a list of strings.") + + if hasattr(encoded_inputs, "text_inputs"): + texts_list = encoded_inputs.text_inputs + + text_inputs = [] + for texts in texts_list: + text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length) + text_inputs.append(text_input_list.unsqueeze(0)) + + encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0) + + return encoded_inputs + + def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.encode_inputs`] and then tokenizes the + task_inputs. Please refer to the docstring of this method for more information. + """ + + if task_inputs is None: + raise ValueError("You have to specify the task_input. Found None.") + elif images is None: + raise ValueError("You have to specify the image. Found None.") + + if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): + raise ValueError("task_inputs must be semantic, instance, or panoptic.") + + encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs) + + if isinstance(task_inputs, str): + task_inputs = [task_inputs] + + if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs): + task_token_inputs = [] + for task in task_inputs: + task_input = f"the task is {task}" + task_token_inputs.append(task_input) + encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) + else: + raise TypeError("Task Inputs should be a string or a list of strings.") + + if hasattr(encoded_inputs, "text_inputs"): + texts_list = encoded_inputs.text_inputs + + text_inputs = [] + for texts in texts_list: + text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length) + text_inputs.append(text_input_list.unsqueeze(0)) + + encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0) + + return encoded_inputs + + def post_process_semantic_segmentation(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.post_process_semantic_segmentation`]. + Please refer to the docstring of this method for more information. 
+ """ + return self.image_processor.post_process_semantic_segmentation(*args, **kwargs) + + def post_process_instance_segmentation(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.post_process_instance_segmentation`]. + Please refer to the docstring of this method for more information. + """ + return self.image_processor.post_process_instance_segmentation(*args, **kwargs) + + def post_process_panoptic_segmentation(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OneFormerImageProcessor.post_process_panoptic_segmentation`]. + Please refer to the docstring of this method for more information. + """ + return self.image_processor.post_process_panoptic_segmentation(*args, **kwargs) + + +__all__ = ["OneFormerProcessor"] diff --git a/docs/transformers/src/transformers/models/openai/__init__.py b/docs/transformers/src/transformers/models/openai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a07b0ab669f3a8b386e0fc99e99f5d5a780ef9c5 --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_openai import * + from .modeling_openai import * + from .modeling_tf_openai import * + from .tokenization_openai import * + from .tokenization_openai_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/openai/configuration_openai.py b/docs/transformers/src/transformers/models/openai/configuration_openai.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f2fae9d304bc63b6e69a3e456141531a644bfd --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/configuration_openai.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""OpenAI GPT configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class OpenAIGPTConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is + used to instantiate a GPT model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT + [openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 40478): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OpenAIGPTModel`] or [`TFOpenAIGPTModel`]. + n_positions (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + afn (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`int`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`str`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. 
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and + [`OpenAIGPTDoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + + + Examples: + + ```python + >>> from transformers import OpenAIGPTConfig, OpenAIGPTModel + + >>> # Initializing a GPT configuration + >>> configuration = OpenAIGPTConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = OpenAIGPTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "openai-gpt" + attribute_map = { + "max_position_embeddings": "n_positions", + "hidden_size": "n_embd", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=40478, + n_positions=512, + n_embd=768, + n_layer=12, + n_head=12, + afn="gelu", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.afn = afn + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + super().__init__(**kwargs) + + +__all__ = ["OpenAIGPTConfig"] diff --git a/docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5218c204262f639a0b862c4106a3a04dc27d0b --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
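+# A sketch of a typical invocation of this script (the paths below are placeholders;
+# see the argparse definitions at the bottom of the file for the exact flags):
+#
+#     python convert_openai_original_tf_checkpoint_to_pytorch.py \
+#         --openai_checkpoint_folder_path /path/to/openai/checkpoint \
+#         --pytorch_dump_folder_path /path/to/pytorch/dump \
+#         --openai_config_file /path/to/config.json
+#
+# `--openai_config_file` is optional; when omitted, a default `OpenAIGPTConfig()` is used.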
+"""Convert OpenAI GPT checkpoint.""" + +import argparse + +import torch + +from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + + +def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): + # Construct model + if openai_config_file == "": + config = OpenAIGPTConfig() + else: + config = OpenAIGPTConfig.from_json_file(openai_config_file) + model = OpenAIGPTModel(config) + + # Load weights from numpy + load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME + print(f"Save PyTorch model to {pytorch_weights_dump_path}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {pytorch_config_dump_path}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--openai_checkpoint_folder_path", + default=None, + type=str, + required=True, + help="Path to the TensorFlow checkpoint path.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--openai_config_file", + default="", + type=str, + help=( + "An optional config json file corresponding to the pre-trained OpenAI model. \n" + "This specifies the model architecture." + ), + ) + args = parser.parse_args() + convert_openai_checkpoint_to_pytorch( + args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path + ) diff --git a/docs/transformers/src/transformers/models/openai/modeling_openai.py b/docs/transformers/src/transformers/models/openai/modeling_openai.py new file mode 100644 index 0000000000000000000000000000000000000000..244ad6d50a2e9bcb3a9ef8899808627db26dcefb --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/modeling_openai.py @@ -0,0 +1,967 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OpenAI GPT model.""" + +import json +import math +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import gelu_new, get_activation, silu +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_openai import OpenAIGPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt" +_CONFIG_FOR_DOC = "OpenAIGPTConfig" + + +def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): + """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)""" + import re + + import numpy as np + + if ".ckpt" in openai_checkpoint_folder_path: + openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path) + + logger.info(f"Loading weights from {openai_checkpoint_folder_path}") + + with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle: + names = json.load(names_handle) + with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle: + shapes = json.load(shapes_handle) + offsets = np.cumsum([np.prod(shape) for shape in shapes]) + init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)] + init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] + init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] + + # This was used when we had a single embedding matrix for positions and tokens + # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) + # del init_params[1] + init_params = [arr.squeeze() for arr in init_params] + + # Check that the token and position embeddings weight dimensions map those of the init parameters. 
+ if model.tokens_embed.weight.shape != init_params[1].shape: + raise ValueError( + f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:" + f" {init_params[1].shape}" + ) + + if model.positions_embed.weight.shape != init_params[0].shape: + raise ValueError( + f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:" + f" {init_params[0].shape}" + ) + + model.tokens_embed.weight.data = torch.from_numpy(init_params[1]) + model.positions_embed.weight.data = torch.from_numpy(init_params[0]) + names.pop(0) + # Pop position and token embedding arrays + init_params.pop(0) + init_params.pop(0) + + for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]): + name = name[6:] # skip "model/" + if name[-2:] != ":0": + raise ValueError(f"Layer {name} does not end with :0") + name = name[:-2] + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "w": + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + # Ensure that the pointer and array have compatible shapes. + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu} + + +class Attention(nn.Module): + def __init__(self, nx, n_positions, config, scale=False): + super().__init__() + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implementation] + if n_state % config.n_head != 0: + raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}") + self.register_buffer( + "bias", + torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions), + persistent=False, + ) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + + self.c_attn = Conv1D(n_state * 3, nx) + self.c_proj = Conv1D(n_state, nx) + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_head, self.split_size // self.n_head, self.pruned_heads + ) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + # Update hyper params + self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) + self.n_head = self.n_head - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): + w = torch.matmul(q, k) + if self.scale: + w = w / math.sqrt(v.size(-1)) + # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: 
mask_attn_weights + # XD: self.b may be larger than w, so we need to crop it + b = self.bias[:, :, : w.size(-2), : w.size(-1)] + w = w * b + -1e4 * (1 - b) + + if attention_mask is not None: + # Apply the attention mask + w = w + attention_mask + + w = nn.functional.softmax(w, dim=-1) + w = self.attn_dropout(w) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + + outputs = [torch.matmul(w, v)] + if output_attentions: + outputs.append(w) + return outputs + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states + if k: + return x.permute(0, 2, 3, 1) + else: + return x.permute(0, 2, 1, 3) + + def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): + x = self.c_attn(x) + query, key, value = x.split(self.split_size, dim=2) + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + + attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a) + + outputs = [a] + attn_outputs[1:] + return outputs # a, (attentions) + + +class MLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) + super().__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.act = ACT_FNS[config.afn] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +class Block(nn.Module): + def __init__(self, n_positions, config, scale=False): + super().__init__() + nx = config.n_embd + self.attn = Attention(nx, n_positions, config, scale) + self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) + self.mlp = MLP(4 * nx, config) + self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) + + def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): + attn_outputs = self.attn( + x, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + a = attn_outputs[0] + + n = self.ln_1(x + a) + m = self.mlp(n) + h = self.ln_2(n + m) + + outputs = [h] + attn_outputs[1:] + return outputs + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT +class OpenAIGPTSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`OpenAIGPTConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. 
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: OpenAIGPTConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +class OpenAIGPTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = OpenAIGPTConfig + load_tf_weights = load_tf_weights_in_openai_gpt + base_model_prefix = "transformer" + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + mc_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +OPENAI_GPT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OPENAI_GPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", + OPENAI_GPT_START_DOCSTRING, +) +class OpenAIGPTModel(OpenAIGPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd) + self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)]) + + self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tokens_embed + + def set_input_embeddings(self, new_embeddings): + self.tokens_embed = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if position_ids is None: + # Code is different from when we had a single embedding matrix from position and token embeddings + position_ids = self.position_ids[None, : input_shape[-1]] + + # Attention mask. + if attention_mask is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
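+            # e.g. a mask row of [1, 1, 0] is turned below into
+            # [0.0, 0.0, torch.finfo(dtype).min]: zero bias where attention is allowed
+            # and a very large negative bias on the masked position.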
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.tokens_embed(input_ids) + position_embeds = self.positions_embed(position_ids) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + token_type_embeds = self.tokens_embed(token_type_ids) + else: + token_type_embeds = 0 + hidden_states = inputs_embeds + position_embeds + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions) + hidden_states = outputs[0] + if output_attentions: + all_attentions = all_attentions + (outputs[1],) + + hidden_states = hidden_states.view(*output_shape) + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +@add_start_docstrings( + """ + OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + OPENAI_GPT_START_DOCSTRING, +) +class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = OpenAIGPTModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Flatten the tokens + loss = self.loss_function( + lm_logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, + logits=lm_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + # Overwritten -- old model with reduced inputs + return {"input_ids": input_ids} + + +@add_start_docstrings( + """ +OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). 
+""", + OPENAI_GPT_START_DOCSTRING, +) +class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + config.num_labels = 1 + self.transformer = OpenAIGPTModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = OpenAIGPTSequenceSummary(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are + ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") + >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") + >>> tokenizer.add_special_tokens( + ... {"cls_token": "[CLS]"} + ... ) # Add a [CLS] to the vocabulary (we should train it also!) 
+ >>> model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + lm_loss, mc_loss = None, None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return OpenAIGPTDoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer). + [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the + last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding + token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since + it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take + the last value in each row of the batch). 
+ """, + OPENAI_GPT_START_DOCSTRING, +) +class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = OpenAIGPTModel(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + # Ensure the batch size is > 1 if there is no padding. + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "OpenAIGPTDoubleHeadsModel", + "OpenAIGPTForSequenceClassification", + "OpenAIGPTLMHeadModel", + "OpenAIGPTModel", + "OpenAIGPTPreTrainedModel", + "load_tf_weights_in_openai_gpt", +] diff --git a/docs/transformers/src/transformers/models/openai/modeling_tf_openai.py b/docs/transformers/src/transformers/models/openai/modeling_tf_openai.py new file mode 100644 index 0000000000000000000000000000000000000000..3856711d10625495db9591736306b9f689898182 --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/modeling_tf_openai.py @@ -0,0 +1,937 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
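+# A minimal TF usage sketch, assuming the "openai-community/openai-gpt" checkpoint and
+# the `TFOpenAIGPTLMHeadModel` class defined later in this file:
+#
+#     from transformers import AutoTokenizer, TFOpenAIGPTLMHeadModel
+#
+#     tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
+#     model = TFOpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt")
+#     inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+#     outputs = model(inputs)
+#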
+"""TF 2.0 OpenAI GPT model.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFConv1D, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_openai import OpenAIGPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt" +_CONFIG_FOR_DOC = "OpenAIGPTConfig" + + +class TFAttention(keras.layers.Layer): + def __init__(self, nx, config, scale=False, **kwargs): + super().__init__(**kwargs) + + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implementation] + assert n_state % config.n_head == 0, ( + f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}" + ) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + self.output_attentions = config.output_attentions + + self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn") + self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj") + self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) + self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) + self.n_state = n_state + self.pruned_heads = set() + + def prune_heads(self, heads): + pass + + @staticmethod + def causal_attention_mask(nd, ns): + """ + 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), + -1, ns-nd), but doesn't produce garbage on TPUs. + """ + i = tf.range(nd)[:, None] + j = tf.range(ns) + m = i >= j - ns + nd + return m + + def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False): + # q, k, v have shape [batch, heads, sequence, features] + w = tf.matmul(q, k, transpose_b=True) + if self.scale: + dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores + w = w / tf.math.sqrt(dk) + + # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 
+ _, _, nd, ns = shape_list(w) + b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype) + b = tf.reshape(b, [1, 1, nd, ns]) + w = w * b - 1e4 * (1 - b) + + if attention_mask is not None: + # Apply the attention mask + attention_mask = tf.cast(attention_mask, dtype=w.dtype) + w = w + attention_mask + + w = stable_softmax(w, axis=-1) + w = self.attn_dropout(w, training=training) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + + outputs = [tf.matmul(w, v)] + if output_attentions: + outputs.append(w) + return outputs + + def merge_heads(self, x): + x = tf.transpose(x, [0, 2, 1, 3]) + x_shape = shape_list(x) + new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] + return tf.reshape(x, new_x_shape) + + def split_heads(self, x): + x_shape = shape_list(x) + new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) + + def call(self, x, attention_mask, head_mask, output_attentions, training=False): + x = self.c_attn(x) + query, key, value = tf.split(x, 3, axis=2) + query = self.split_heads(query) + key = self.split_heads(key) + value = self.split_heads(value) + + attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a, training=training) + + outputs = [a] + attn_outputs[1:] + return outputs # a, (attentions) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "c_attn", None) is not None: + with tf.name_scope(self.c_attn.name): + self.c_attn.build([None, None, self.n_state * 3]) + if getattr(self, "c_proj", None) is not None: + with tf.name_scope(self.c_proj.name): + self.c_proj.build([None, None, self.n_state]) + + +class TFMLP(keras.layers.Layer): + def __init__(self, n_state, config, **kwargs): + super().__init__(**kwargs) + nx = config.n_embd + self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc") + self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj") + self.act = get_tf_activation("gelu") + self.dropout = keras.layers.Dropout(config.resid_pdrop) + self.nx = nx + self.n_state = n_state + + def call(self, x, training=False): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + h2 = self.dropout(h2, training=training) + return h2 + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "c_fc", None) is not None: + with tf.name_scope(self.c_fc.name): + self.c_fc.build([None, None, self.n_state]) + if getattr(self, "c_proj", None) is not None: + with tf.name_scope(self.c_proj.name): + self.c_proj.build([None, None, self.nx]) + + +class TFBlock(keras.layers.Layer): + def __init__(self, config, scale=False, **kwargs): + super().__init__(**kwargs) + nx = config.n_embd + self.attn = TFAttention(nx, config, scale, name="attn") + self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1") + self.mlp = TFMLP(4 * nx, config, name="mlp") + self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2") + self.nx = nx + + def call(self, x, attention_mask, head_mask, output_attentions, training=False): + output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training) + a = output_attn[0] # output_attn: a, (attentions) + + n = 
self.ln_1(x + a) + m = self.mlp(n, training=training) + h = self.ln_2(n + m) + + outputs = [h] + output_attn[1:] + return outputs # x, (attentions) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "ln_1", None) is not None: + with tf.name_scope(self.ln_1.name): + self.ln_1.build([None, None, self.nx]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "ln_2", None) is not None: + with tf.name_scope(self.ln_2.name): + self.ln_2.build([None, None, self.nx]) + + +@keras_serializable +class TFOpenAIGPTMainLayer(keras.layers.Layer): + config_class = OpenAIGPTConfig + + def __init__(self, config, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + self.config = config + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.return_dict = config.use_return_dict + self.num_hidden_layers = config.n_layer + self.n_embd = config.n_embd + self.n_positions = config.n_positions + self.initializer_range = config.initializer_range + + self.tokens_embed = TFSharedEmbeddings( + config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed" + ) + self.drop = keras.layers.Dropout(config.embd_pdrop) + self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)] + + def build(self, input_shape=None): + with tf.name_scope("positions_embed"): + self.positions_embed = self.add_weight( + name="embeddings", + shape=[self.n_positions, self.n_embd], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "tokens_embed", None) is not None: + with tf.name_scope(self.tokens_embed.name): + self.tokens_embed.build(None) + if getattr(self, "h", None) is not None: + for layer in self.h: + with tf.name_scope(layer.name): + layer.build(None) + + def get_input_embeddings(self): + return self.tokens_embed + + def set_input_embeddings(self, value): + self.tokens_embed.weight = value + self.tokens_embed.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutput]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if position_ids is None: + position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0) + + if attention_mask is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
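+ # For example, a 2D mask row [1, 1, 0] becomes additive scores [0.0, 0.0, -10000.0] after the + # cast/subtract/multiply below, so the padded position receives a (near-)zero softmax weight.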
+ + one_cst = tf.constant(1.0) + attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) + attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) + else: + attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + # head_mask = tf.constant([0] * self.num_hidden_layers) + + position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = self.tokens_embed(input_ids, mode="embedding") + position_embeds = tf.gather(self.positions_embed, position_ids) + if token_type_ids is not None: + token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) + check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids") + token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding") + else: + token_type_embeds = 0 + hidden_states = inputs_embeds + position_embeds + token_type_embeds + hidden_states = self.drop(hidden_states, training=training) + + output_shape = input_shape + [shape_list(hidden_states)[-1]] + + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) + + outputs = block( + hidden_states, + attention_mask, + head_mask[i], + output_attentions, + training=training, + ) + hidden_states = outputs[0] + if output_attentions: + all_attentions = all_attentions + (outputs[1],) + + hidden_states = tf.reshape(hidden_states, output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if output_attentions: + # let the number of heads free (-1) so we can extract attention even after head pruning + attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] + all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = OpenAIGPTConfig + base_model_prefix = "transformer" + + +@dataclass +class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[tf.Tensor] = None + mc_logits: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +OPENAI_GPT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OPENAI_GPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). 
+""" + + +@add_start_docstrings( + "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.", + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutput]: + outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +@add_start_docstrings( + """ + OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + # OpenAIGPT does not have past caching features + self.supports_xla_generation = False + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFCausalLMOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. 
+ """ + + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + + logits = self.transformer.tokens_embed(hidden_states, mode="linear") + + loss = None + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels, shifted_logits) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, inputs, **kwargs): + return {"input_ids": inputs} + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +@add_start_docstrings( + """ + OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for + RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the + input embeddings, the classification head takes as input the input of a specified classification token index in the + input sequence). + """, + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + config.num_labels = 1 + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + self.multiple_choice_head = TFSequenceSummary( + config, initializer_range=config.initializer_range, name="multiple_choice_head" + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + mc_token_ids: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFOpenAIGPTDoubleHeadsModelOutput]: + r""" + mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. 
+ + Return: + + Examples: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt") + >>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size + >>> print(tokenizer.cls_token_id, len(tokenizer)) # The newly added token is the last token of the vocabulary + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoding = tokenizer(choices, return_tensors="tf") + >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()} + >>> inputs["mc_token_ids"] = tf.constant( + ... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1] + ... )[ + ... None, : + ... ] # Batch size 1 + >>> outputs = model(inputs) + >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] + ```""" + + if input_ids is not None: + input_shapes = shape_list(input_ids) + else: + input_shapes = shape_list(inputs_embeds)[:-1] + + seq_length = input_shapes[-1] + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + transformer_outputs = self.transformer( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) + if return_dict and output_hidden_states: + # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the + # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) + all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) + else: + all_hidden_states = None + lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) + mc_logits = tf.squeeze(mc_logits, axis=-1) + + if not return_dict: + return (lm_logits, mc_logits) + transformer_outputs[1:] + + return TFOpenAIGPTDoubleHeadsModelOutput( + logits=lm_logits, + mc_logits=mc_logits, + hidden_states=all_hidden_states, + attentions=transformer_outputs.attentions, + ) + + @property + def input_signature(self): + return { + "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), + "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="mc_token_ids"), + } + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "multiple_choice_head", None) is not None: + with tf.name_scope(self.multiple_choice_head.name): + 
self.multiple_choice_head.build(None) + + +@add_start_docstrings( + """ + The OpenAI GPT Model transformer with a sequence classification head on top (linear layer). + + [`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-2) do. + + Since it does classification on the last token, it needs to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPENAI_GPT_START_DOCSTRING, +) +class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.score = keras.layers.Dense( + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="score", + use_bias=False, + ) + self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFSequenceClassifierOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.num_labels - 1]`. + """ + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + logits_shape = shape_list(logits) + batch_size = logits_shape[0] + + if self.config.pad_token_id is None: + last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) + else: + if input_ids is not None: + # e.g. input_ids [[5, 6, pad, pad]] -> mask [[1, 1, 0, 0]], indices [0, 1, 2, 3], + # product [[0, 1, 0, 0]], reduce_max -> 1, the position of the last non-pad token + token_indices = tf.range(shape_list(input_ids)[-1]) + non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype) + last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1) + else: + last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1) + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + loss = None + + pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1) + + if labels is not None: + if self.config.pad_token_id is None and logits_shape[0] != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + + loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels])) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "score", None) is not None: + with tf.name_scope(self.score.name): + self.score.build([None, None, self.config.n_embd]) + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +__all__ = [ + "TFOpenAIGPTDoubleHeadsModel", + "TFOpenAIGPTForSequenceClassification", + "TFOpenAIGPTLMHeadModel", + "TFOpenAIGPTMainLayer", + "TFOpenAIGPTModel", + "TFOpenAIGPTPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/openai/tokenization_openai.py b/docs/transformers/src/transformers/models/openai/tokenization_openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfb41fc888fd62d13bc41b9f0674ce54c288d83 --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/tokenization_openai.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" + +import json +import os +import re +import unicodedata +from typing import Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. 
Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
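+ # e.g. "ab中cd" -> "ab 中 cd", so each CJK character is split out as its own token by the + # whitespace tokenization that follows.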
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. 
word is represented as tuple of symbols (symbols being variable-length + strings) + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def text_standardize(text): + """ + Fixes some issues the spacy tokenizer had on books corpus; also does some whitespace standardization. + """ + text = text.replace("—", "-") + text = text.replace("–", "-") + text = text.replace("―", "-") + text = text.replace("…", "...") + text = text.replace("´", "'") + text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text) + text = re.sub(r"\s*\n\s*", " \n ", text) + text = re.sub(r"[^\S\n]+", " ", text) + return text.strip() + + +class OpenAIGPTTokenizer(PreTrainedTokenizer): + """ + Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities: + + - lowercases all inputs, + - uses `SpaCy` tokenizer and `ftfy` for pre-BPE tokenization if they are installed, falling back to BERT's + `BasicTokenizer` if not. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): + try: + import ftfy + from spacy.lang.en import English + + _nlp = English() + self.nlp = _nlp.tokenizer + self.fix_text = ftfy.fix_text + except ImportError: + logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.") + self.nlp = BasicTokenizer(do_lower_case=True) + self.fix_text = None + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + merges = merges_handle.read().split("\n")[1:-1] + merges = [tuple(merge.split()) for merge in merges] + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {} + + super().__init__(unk_token=unk_token, **kwargs) + + @property + def do_lower_case(self): + return True + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + word = tuple(token[:-1]) + (token[-1] + "</w>",) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + "</w>" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n  </w>": + word = "\n</w>" + 
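# Memoize the final segmentation so repeated occurrences of the same token skip the merge loop. +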
self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + split_tokens = [] + if self.fix_text is None: + # Using BERT's BasicTokenizer + text = self.nlp.tokenize(text) + for token in text: + split_tokens.extend(list(self.bpe(token).split(" "))) + else: + # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) + text = self.nlp(text_standardize(self.fix_text(text))) + for token in text: + split_tokens.extend(list(self.bpe(token.text.lower()).split(" "))) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) into an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an id into a token (BPE) using the vocab.""" + return self.decoder.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) into a single string.""" + out_string = "".join(tokens).replace("</w>", " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + +__all__ = ["OpenAIGPTTokenizer"] diff --git a/docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py b/docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..c17d7d29b7dd01dc48e960e1ee6f4e26d29aff1b --- /dev/null +++ b/docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
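+# A quick round-trip sketch for the slow tokenizer above (illustrative, not part of the original file; outputs assume the "openai-community/openai-gpt" vocabulary): +# from transformers import OpenAIGPTTokenizer +# tok = OpenAIGPTTokenizer.from_pretrained("openai-community/openai-gpt") +# ids = tok("Hello world!")["input_ids"] +# tok.convert_ids_to_tokens(ids) # e.g. ['hello</w>', 'world</w>', '!</w>'] +# tok.decode(ids) # 'hello world !' since convert_tokens_to_string() turns '</w>' markers into spaces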
+"""Fast Tokenization classes for OpenAI GPT.""" + +from typing import Optional, Tuple + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_openai import OpenAIGPTTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" GPT Tokenizer (backed by HuggingFace's *tokenizers* library). Based on Byte-Pair-Encoding with + the following peculiarities: + + - lowercases all inputs + - uses BERT's BasicTokenizer for pre-BPE tokenization + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = OpenAIGPTTokenizer + + def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<unk>", **kwargs): + super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs) + + @property + def do_lower_case(self): + return True + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["OpenAIGPTTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/opt/__init__.py b/docs/transformers/src/transformers/models/opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d230de5ecadc752b4c0a1784a6295cebbb20ed55 --- /dev/null +++ b/docs/transformers/src/transformers/models/opt/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_opt import * + from .modeling_flax_opt import * + from .modeling_opt import * + from .modeling_tf_opt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/opt/configuration_opt.py b/docs/transformers/src/transformers/models/opt/configuration_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8838ad9ef18d406bdf4e4dd3c040f609261573 --- /dev/null +++ b/docs/transformers/src/transformers/models/opt/configuration_opt.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OPT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class OPTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OPTModel`]. It is used to instantiate a OPT model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the OPT + [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50272): + Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`OPTModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + ffn_dim (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + do_layer_norm_before (`bool`, *optional*, defaults to `True`): + Whether to perform layer normalization before the attention block. 
+ word_embed_proj_dim (`int`, *optional*): + `word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to + `hidden_size`. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more + details. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + enable_bias (`bool`, *optional*, defaults to `True`): + Whether or not the linear layers in the attention blocks should use the bias term. + layer_norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not the layer norms should have learnable parameters. + + Example: + + ```python + >>> from transformers import OPTConfig, OPTModel + + >>> # Initializing an OPT facebook/opt-large style configuration + >>> configuration = OPTConfig() + + >>> # Initializing a model (with random weights) from the facebook/opt-large style configuration + >>> model = OPTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "opt" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50272, + hidden_size=768, + num_hidden_layers=12, + ffn_dim=3072, + max_position_embeddings=2048, + do_layer_norm_before=True, + _remove_final_layer_norm=False, + word_embed_proj_dim=None, + dropout=0.1, + attention_dropout=0.0, + num_attention_heads=12, + activation_function="relu", + layerdrop=0.0, + init_std=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=2, + eos_token_id=2, + enable_bias=True, + layer_norm_elementwise_affine=True, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.num_attention_heads = num_attention_heads + self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size + self.ffn_dim = ffn_dim + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.init_std = init_std + self.layerdrop = layerdrop + self.use_cache = use_cache + self.do_layer_norm_before = do_layer_norm_before + # We keep these variables at `True` for backward compatibility.
+ + self.enable_bias = enable_bias + self.layer_norm_elementwise_affine = layer_norm_elementwise_affine + + # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + self._remove_final_layer_norm = _remove_final_layer_norm + + +__all__ = ["OPTConfig"] diff --git a/docs/transformers/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9b0c306cbe1cff685853a079017701f8af86e4 --- /dev/null +++ b/docs/transformers/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OPT checkpoint.""" + +import argparse +from pathlib import Path + +import torch + +from transformers import OPTConfig, OPTModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def load_checkpoint(checkpoint_path): + """Checkpoint path should end in model.pt""" + sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + if "model" in sd.keys(): + # the state dict is nested under a "model" key; unwrap it rather than loading the file a second time + sd = sd["model"] + + # pop unnecessary weights + keys_to_delete = [ + "decoder.version", + "decoder.output_projection.weight", + ] + for key in keys_to_delete: + if key in sd: + sd.pop(key) + + keys_to_rename = { + "decoder.project_in_dim.weight": "decoder.project_in.weight", + "decoder.project_out_dim.weight": "decoder.project_out.weight", + "decoder.layer_norm.weight": "decoder.final_layer_norm.weight", + "decoder.layer_norm.bias": "decoder.final_layer_norm.bias", + } + for old_key, new_key in keys_to_rename.items(): + if old_key in sd: + sd[new_key] = sd.pop(old_key) + + keys = list(sd.keys()) + for key in keys: + if ".qkv_proj." in key: + value = sd[key] + # We split the fused QKV weight into separate Q, K, V projections + + q_name = key.replace(".qkv_proj.", ".q_proj.") + k_name = key.replace(".qkv_proj.", ".k_proj.") + v_name = key.replace(".qkv_proj.", ".v_proj.") + + depth = value.shape[0] + assert depth % 3 == 0 + # `SequeuceParallelTransformerBlock` stores the fused weight in K, V, Q order despite the `qkv` naming: + # https://cs.github.com/facebookresearch/metaseq/blob/51871bd73cd04c038f239ea2a26db1d7f6b37927/metaseq/modules/sequence_parallel_transformer_layer.py#L97 + k, v, q = torch.split(value, depth // 3, dim=0) + + sd[q_name] = q + sd[k_name] = k + sd[v_name] = v + del sd[key] + + return sd + + +@torch.no_grad() +def convert_opt_checkpoint(checkpoint_path, pytorch_dump_folder_path, config=None): + """ + Copy/paste/tweak model's weights to our OPT structure. 
+ """ + state_dict = load_checkpoint(checkpoint_path) + + if config is not None: + config = OPTConfig.from_pretrained(config) + else: + config = OPTConfig() + + model = OPTModel(config).half().eval() + model.load_state_dict(state_dict) + + # Check results + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--fairseq_path", + type=str, + help=( + "path to fairseq checkpoint in correct format. You can find all checkpoints in the correct format here:" + " https://huggingface.co/models?other=opt_metasq" + ), + ) + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--hf_config", default=None, type=str, help="Define HF config.") + args = parser.parse_args() + convert_opt_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, config=args.hf_config) diff --git a/docs/transformers/src/transformers/models/opt/modeling_flax_opt.py b/docs/transformers/src/transformers/models/opt/modeling_flax_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..fc023bb4ae84a178bb6ab1f2f817663cae0e44d6 --- /dev/null +++ b/docs/transformers/src/transformers/models/opt/modeling_flax_opt.py @@ -0,0 +1,802 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax OPT model.""" + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...utils import add_start_docstrings, logging +from .configuration_opt import OPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + + +OPT_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. 
+ + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`OPTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT +class FlaxOPTAttention(nn.Module): + config: OPTConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
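+            # Worked example (hypothetical numbers): with max_length=8, cur_index=3 and one new token
+            # (num_updated_cache_vectors=1), the pad_mask below marks key positions 0..3 as attendable
+            # and positions 4..7 as masked, since those cache slots are still zero-initialized.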
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
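+        # (0.0 where attending is allowed, the dtype's most negative representable value where it is
+        # masked, so that softmax assigns masked positions effectively zero weight)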
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxOPTDecoderLayer(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.hidden_size + self.self_attn = FlaxOPTAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.num_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.do_layer_norm_before = self.config.do_layer_norm_before + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + init_cache=init_cache, + deterministic=deterministic, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + hidden_states = (residual + hidden_states).reshape(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + 
return outputs + + +class FlaxOPTDecoderLayerCollection(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxOPTDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + self.layerdrop = self.config.layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + outputs = [hidden_states, all_hidden_states, all_self_attns] + return outputs + + +class FlaxOPTLearnedPositionalEmbedding(nn.Embed): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def setup(self): + self.offset = 2 + self.embedding = self.param( + "embedding", self.embedding_init, (self.num_embeddings + self.offset, self.features), self.param_dtype + ) + + def __call__(self, positions): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + return super().__call__(positions + self.offset) + + +class FlaxOPTDecoder(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + offset: int = 2 + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.hidden_size + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + + self.embed_tokens = nn.Embed( + self.config.vocab_size, + self.config.word_embed_proj_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.embed_positions = FlaxOPTLearnedPositionalEmbedding( + self.config.max_position_embeddings, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + if self.config.word_embed_proj_dim != self.config.hidden_size: + self.project_in = nn.Dense(self.config.hidden_size, use_bias=False) + self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False) + + else: + self.project_in = None + self.project_out = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + else: + self.final_layer_norm = None + + self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) + if 
self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + positions = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + positions + + hidden_state, all_hidden_states, attentions = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if self.final_layer_norm is not None: + hidden_state = self.final_layer_norm(hidden_state) + + if self.project_out is not None: + hidden_state = self.project_out(hidden_state) + + if output_hidden_states: + all_hidden_states += (hidden_state,) + + outputs = [hidden_state, all_hidden_states, attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=all_hidden_states, + attentions=attentions, + ) + + +class FlaxOPTPreTrainedModel(FlaxPreTrainedModel): + config_class = OPTConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: OPTConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + position_ids, + return_dict=False, + ) + + random_params = module_init_outputs["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
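+
+        Example (a sketch under the assumption that `model` is an already-instantiated
+        `FlaxOPTForCausalLM`; the sizes are placeholders):
+
+        ```python
+        >>> past_key_values = model.init_cache(batch_size=1, max_length=32)
+        >>> # the initialized cache can then be passed to the model call via `past_key_values=`
+        ```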
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + params: dict = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + dropout_rng: PRNGKey = None, + deterministic: bool = True, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1 + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxOPTAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + return outputs + + +class FlaxOPTModule(nn.Module): + config: OPTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype) + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache=False, + ): + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + init_cache=init_cache, + ) + + if not return_dict: + return decoder_outputs + + return FlaxBaseModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + 
hidden_states=decoder_outputs.hidden_states,
+            attentions=decoder_outputs.attentions,
+        )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT
+class FlaxOPTModel(FlaxOPTPreTrainedModel):
+    config: OPTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    module_class = FlaxOPTModule
+
+
+append_call_sample_docstring(FlaxOPTModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
+
+
+@add_start_docstrings(
+    "OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings).",
+    OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLMModule(nn.Module):
+    config: OPTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.model = FlaxOPTModule(config=self.config, dtype=self.dtype)
+        self.lm_head = nn.Dense(
+            self.config.vocab_size,
+            use_bias=False,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
+        )
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        init_cache: bool = False,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        outputs = self.model(
+            input_ids,
+            attention_mask,
+            position_ids,
+            init_cache=init_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = outputs[0]
+
+        if self.config.tie_word_embeddings:
+            shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
+            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+        else:
+            lm_logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            return (lm_logits,) + outputs[1:]
+
+        return FlaxMaskedLMOutput(
+            logits=lm_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings) e.g. for
+    autoregressive tasks.
+    """,
+    OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
+    module_class = FlaxOPTForCausalLMModule
+
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
+        # initializing the cache
+        batch_size, seq_length = input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxOPTForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutput, + _CONFIG_FOR_DOC, +) + + +__all__ = ["FlaxOPTForCausalLM", "FlaxOPTModel", "FlaxOPTPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/opt/modeling_opt.py b/docs/transformers/src/transformers/models/opt/modeling_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..7e097cca311c2b65d4b26fbb27d2f52cca5ae293 --- /dev/null +++ b/docs/transformers/src/transformers/models/opt/modeling_opt.py @@ -0,0 +1,1517 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OPT model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: OPTConfig, + layer_idx: Optional[int] = None, + **kwargs, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.enable_bias = config.enable_bias + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+            )
+
+        self.head_dim = self.embed_dim // self.num_heads
+        self.is_causal = True
+
+        if (self.head_dim * self.num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        # not needed in eager attention, but needed in flash attention, so we keep the signature the same
+        position_ids: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        """Input shape: Batch x Time x Channel"""
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 2))
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_probs, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_probs, past_key_value
+
+
+class OptFlashAttention2(OPTAttention):
+    """
+    OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
+    The only required change would be on the forward pass where it needs to correctly call the public API of flash
+    attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        position_ids: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, query_length, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
+
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+
+        attn_dropout = self.dropout if self.training else 0.0
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need
+        # to cast them back to float16 just to be sure everything works as expected.
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be related to"
+                f" having upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_length,
+            position_ids=position_ids,
+            dropout=attn_dropout,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+        attn_output = self.out_proj(attn_weights_reshaped)
+
+        if not output_attentions:
+            attn_weights_reshaped = None
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class OPTSdpaAttention(OPTAttention):
+    """
+    OPT SDPA attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
+    The only required change would be on the forward pass where it needs to correctly call the public API of SDPA
+    attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        position_ids: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions or layer_head_mask is not None:
+            logger.warning_once(
+                "OPTModel is using SDPA attention, which currently does not support output_attentions=True. "
+                'Falling back to eager attention. This warning can be removed by setting attn_implementation="eager".'
+            )
+
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
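+        # For example: during prefill with no padding mask (causal_mask is None and q_len > 1), SDPA builds the
+        # causal mask internally; during single-token decoding (q_len == 1) no causal masking is required at all.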
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+OPT_ATTENTION_CLASSES = {
+    "eager": OPTAttention,
+    "flash_attention_2": OptFlashAttention2,
+    "sdpa": OPTSdpaAttention,
+}
+
+
+class OPTDecoderLayer(nn.Module):
+    def __init__(self, config: OPTConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+
+        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+        self.do_layer_norm_before = config.do_layer_norm_before
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+
+        self.self_attn_layer_norm = nn.LayerNorm(
+            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
+        )
+        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
+        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        position_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+ """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + position_ids=position_ids, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. for padding use -1. + + [What are position IDs?](../glossary#position-ids) + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) + self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from 
transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. 
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+ + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. for padding use -1. + + [What are position IDs?](../glossary#position-ids) + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if past_key_values is None: + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. 
" + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if attention_mask is None: + seq_length = past_seen_tokens + inputs_embeds.shape[1] + attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + if position_ids is None: + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_seen_tokens:] + + pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + position_ids, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else 
self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence token in the position embeddings. Selected in the range
+                `[0, config.n_positions - 1]`. For padding use -1.
+
+                [What are position IDs?](../glossary#position-ids)
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+                this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+                the complete sequence length.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The OPT Model transformer with a sequence classification head on top (linear layer). + + [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
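+
+    A minimal sketch of the pooling rule described above, showing how the rightmost
+    non-padding token is located when a `pad_token_id` is defined (the input values
+    below are illustrative; the masking logic mirrors the one in `forward`):
+
+    ```python
+    >>> import torch
+
+    >>> pad_token_id = 1  # the padding token id used by the released OPT checkpoints
+    >>> input_ids = torch.tensor([[5, 8, 9, 1, 1]])  # one right-padded sequence
+    >>> non_pad_mask = (input_ids != pad_token_id).int()
+    >>> token_indices = torch.arange(input_ids.shape[-1])
+    >>> (token_indices * non_pad_mask).argmax(-1)  # index of the rightmost real token
+    tensor([2])
+    ```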
+ """, + OPT_START_DOCSTRING, +) +class OPTForSequenceClassification(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index).to(logits.device) + end_positions = end_positions.clamp(0, ignored_index).to(logits.device) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + +__all__ = [ + "OPTForCausalLM", + "OPTModel", + "OPTPreTrainedModel", + "OPTForSequenceClassification", + 
"OPTForQuestionAnswering", +] diff --git a/docs/transformers/src/transformers/models/opt/modeling_tf_opt.py b/docs/transformers/src/transformers/models/opt/modeling_tf_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5ef7616a3c91d40f1187273363424a05085625 --- /dev/null +++ b/docs/transformers/src/transformers/models/opt/modeling_tf_opt.py @@ -0,0 +1,1097 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 OPT model.""" + +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast + +# Public API +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSharedEmbeddings, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_opt import OPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# Causal LM output +_CAUSAL_LM_EXPECTED_OUTPUT = ( + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." +) + +LARGE_NEGATIVE = -1e8 + + +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] + # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it + mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32)) + mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) + + +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFOPTLearnedPositionalEmbedding(keras.layers.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. 
+
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
+        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
+
+    def call(self, attention_mask, past_key_values_length: int = 0):
+        """`attention_mask` is expected to be [bsz x seqlen]."""
+        attention_mask = tf.cast(attention_mask, tf.int64)
+
+        # create positions depending on attention_mask
+        positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1
+
+        # cut positions if `past_key_values_length` is > 0
+        positions = positions[:, past_key_values_length:]
+
+        return super().call(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT
+class TFOPTAttention(keras.layers.Layer):
+    """Multi-headed attention from "Attention Is All You Need"."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+
+        self.num_heads = num_heads
+        self.dropout = keras.layers.Dropout(dropout)
+        self.head_dim = embed_dim // num_heads
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+
+        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        key_value_states: tf.Tensor | None = None,
+        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
+        attention_mask: tf.Tensor | None = None,
+        layer_head_mask: tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Tuple[tf.Tensor, tf.Tensor | None]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = tf.concat([past_key_value[0], key_states], axis=2)
+            value_states = tf.concat([past_key_value[1], value_states], axis=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = 
self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + 
self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +class TFOPTDecoderLayer(keras.layers.Layer): + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(**kwargs) + self.do_layer_norm_before = config.do_layer_norm_before + self.embed_dim = config.hidden_size + self.self_attn = TFOPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: np.ndarray | tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + training: Optional[bool] = False, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). 
+
+        """
+        residual = hidden_states
+
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+        if self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+        )
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = residual + hidden_states
+
+        # 350m applies layer norm AFTER attention
+        if not self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Fully Connected
+        residual = hidden_states
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
+        if self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = residual + hidden_states
+
+        # 350m applies layer norm AFTER the feed-forward block
+        if not self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        return (hidden_states, self_attn_weights, present_key_value)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self_attn", None) is not None:
+            with tf.name_scope(self.self_attn.name):
+                self.self_attn.build(None)
+        if getattr(self, "self_attn_layer_norm", None) is not None:
+            with tf.name_scope(self.self_attn_layer_norm.name):
+                self.self_attn_layer_norm.build([None, None, self.embed_dim])
+        if getattr(self, "fc1", None) is not None:
+            with tf.name_scope(self.fc1.name):
+                self.fc1.build([None, None, self.embed_dim])
+        if getattr(self, "fc2", None) is not None:
+            with tf.name_scope(self.fc2.name):
+                self.fc2.build([None, None, self.config.ffn_dim])
+        if getattr(self, "final_layer_norm", None) is not None:
+            with tf.name_scope(self.final_layer_norm.name):
+                self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+OPT_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    </Tip>
+
+    Args:
+        config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+    OPT_START_DOCSTRING,
+)
+class TFOPTPreTrainedModel(TFPreTrainedModel):
+    """
+    TFOPT Pretrained Model that inherits from transformers.TFPreTrainedModel
+
+    Args:
+        config: OPTConfig
+    """
+
+    config_class = OPTConfig
+    base_model_prefix = "model"
+
+
+OPT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. 
This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@keras_serializable +class TFOPTDecoder(keras.layers.Layer): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.layerdrop = config.layerdrop + num_embeddings = config.max_position_embeddings + self.embed_tokens = TFSharedEmbeddings( + config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens" + ) + self.embed_positions = TFOPTLearnedPositionalEmbedding( + num_embeddings, + config.hidden_size, + name="embed_positions", + ) + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + else: + self.final_layer_norm = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False) + self.project_in = keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False) + + else: + self.project_in = None + self.project_out = None + + self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)] + self.dropout = keras.layers.Dropout(config.dropout) + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens.vocab_size = new_embeddings.shape[0] + self.embed_tokens.weight = new_embeddings + + def get_input_embeddings(self): + return self.embed_tokens + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length): + # create causal mask + # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + _, seq_length = input_shape + tf.debugging.assert_equal( + seq_length + past_key_values_length, + shape_list(attention_mask)[1], + message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)" + f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length" + f" {past_key_values_length}.", + ) + + expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1]) + if seq_length > 1: + combined_attention_mask = ( + _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask + ) + else: + combined_attention_mask = expanded_attn_mask + + return combined_attention_mask + + @unpack_inputs + 
def call( + self, + input_ids: TFModelInputType | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]: + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size) + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is None: + attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool) + else: + tf.debugging.assert_equal( + shape_list(attention_mask)[1], + past_key_values_length + input_shape[1], + message=( + f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be " + f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)" + ), + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + present_key_values = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for attn_mask_name, attn_mask in [("head_mask", head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." 
+ ), + ) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None + ) + + else: + return TFBaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_tokens", None) is not None: + with tf.name_scope(self.embed_tokens.name): + self.embed_tokens.build(None) + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "project_out", None) is not None: + with tf.name_scope(self.project_out.name): + self.project_out.build([None, None, self.config.hidden_size]) + if getattr(self, "project_in", None) is not None: + with tf.name_scope(self.project_in.name): + self.project_in.build([None, None, self.config.word_embed_proj_dim]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFOPTMainLayer(keras.layers.Layer): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.decoder = TFOPTDecoder(config, name="decoder") + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.decoder.set_input_embeddings(new_embeddings) + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict + + outputs = self.decoder( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return outputs + + return TFBaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare TF OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +@keras_serializable +class TFOPTModel(TFOPTPreTrainedModel): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.model = TFOPTMainLayer(config, name="model") + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.model.set_input_embeddings(new_embeddings) + + @unpack_inputs + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return outputs + + return TFBaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutputWithPast( + 
last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +@add_start_docstrings( + """ + The OPT Model transformer with a language modeling head on top. + """, + OPT_START_DOCSTRING, +) +@keras_serializable +class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss): + config_class = OPTConfig + + def __init__(self, config: OPTConfig, **kwargs): + super().__init__(config, **kwargs) + self.config = config + self.model = TFOPTMainLayer(config, name="model") + + def get_output_embeddings(self): + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): + attention_mask = kwargs.get("attention_mask", None) + + # only last token for input_ids if past is defined in kwargs + if past_key_values: + inputs = tf.expand_dims(inputs[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @unpack_inputs + @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_CAUSAL_LM_EXPECTED_OUTPUT, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]: + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `Tuple[tf.Tensor]` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
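Example (a minimal usage sketch; `facebook/opt-350m` is one of the public OPT checkpoints and is assumed here purely for illustration):

```python
>>> from transformers import AutoTokenizer, TFOPTForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
>>> model = TFOPTForCausalLM.from_pretrained("facebook/opt-350m")

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> # The language modeling head reuses the input embeddings as an output projection.
>>> logits = model(**inputs).logits  # shape: (batch_size, sequence_length, vocab_size)
```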
+ """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + logits = self.model.decoder.embed_tokens(outputs[0], mode="linear") + loss = None + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels, shifted_logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFCausalLMOutputWithPast( + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + loss=output.loss, + logits=output.logits, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +__all__ = ["TFOPTForCausalLM", "TFOPTModel", "TFOPTPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/owlv2/__init__.py b/docs/transformers/src/transformers/models/owlv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce329420882f7b02546a85addfca96dc19d35aa --- /dev/null +++ b/docs/transformers/src/transformers/models/owlv2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_owlv2 import * + from .image_processing_owlv2 import * + from .modeling_owlv2 import * + from .processing_owlv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/owlv2/configuration_owlv2.py b/docs/transformers/src/transformers/models/owlv2/configuration_owlv2.py new file mode 100644 index 0000000000000000000000000000000000000000..bf70afc7ccf527cc87bf8aa1f0b0016799c8bf9b --- /dev/null +++ b/docs/transformers/src/transformers/models/owlv2/configuration_owlv2.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OWLv2 model configuration""" + +from typing import TYPE_CHECKING, Dict + + +if TYPE_CHECKING: + pass + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.owlvit.configuration_owlvit.OwlViTTextConfig with OwlViT->Owlv2, owlvit-base-patch32->owlv2-base-patch16, owlvit->owlv2, OWL-ViT->OWLv2 +class Owlv2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`Owlv2TextModel`]. It is used to instantiate an + Owlv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Owlv2 + [google/owlv2-base-patch16](https://huggingface.co/google/owlv2-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the OWLv2 text model. Defines the number of different tokens that can be represented + by the `input_ids` passed when calling [`Owlv2TextModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 16): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token in the input sequences. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the input sequences. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the input sequences. + + Example: + + ```python + >>> from transformers import Owlv2TextConfig, Owlv2TextModel + + >>> # Initializing a Owlv2TextConfig with google/owlv2-base-patch16 style configuration + >>> configuration = Owlv2TextConfig() + + >>> # Initializing a Owlv2TextModel (with random weights) from the google/owlv2-base-patch16 style configuration + >>> model = Owlv2TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlv2_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=16, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=0, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +# Copied from transformers.models.owlvit.configuration_owlvit.OwlViTVisionConfig with OwlViT->Owlv2, owlvit-base-patch32->owlv2-base-patch16, owlvit->owlv2, OWL-ViT->OWLv2, 32->16 +class Owlv2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`Owlv2VisionModel`]. It is used to instantiate + an OWLv2 image encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OWLv2 + [google/owlv2-base-patch16](https://huggingface.co/google/owlv2-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information.
+ + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 768): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import Owlv2VisionConfig, Owlv2VisionModel + + >>> # Initializing a Owlv2VisionConfig with google/owlv2-base-patch16 style configuration + >>> configuration = Owlv2VisionConfig() + + >>> # Initializing a Owlv2VisionModel (with random weights) from the google/owlv2-base-patch16 style configuration + >>> model = Owlv2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlv2_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=768, + patch_size=16, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +# Copied from transformers.models.owlvit.configuration_owlvit.OwlViTConfig with OwlViT->Owlv2, owlvit-base-patch32->owlv2-base-patch16, owlvit->owlv2, OWL-ViT->OWLv2 +class Owlv2Config(PretrainedConfig): + r""" + [`Owlv2Config`] is the configuration class to store the configuration of an [`Owlv2Model`]. It is used to + instantiate an OWLv2 model according to the specified arguments, defining the text model and vision model + configs.
Instantiating a configuration with the defaults will yield a similar configuration to that of the OWLv2 + [google/owlv2-base-patch16](https://huggingface.co/google/owlv2-base-patch16) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Owlv2TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Owlv2VisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original OWLv2 + implementation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not the model should return a dictionary. If `False`, returns a tuple. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + model_type = "owlv2" + sub_configs = {"text_config": Owlv2TextConfig, "vision_config": Owlv2VisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + projection_dim=512, + logit_scale_init_value=2.6592, + return_dict=True, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the Owlv2TextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. Initializing the Owlv2VisionConfig with default values.") + + self.text_config = Owlv2TextConfig(**text_config) + self.vision_config = Owlv2VisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.return_dict = return_dict + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: Dict, vision_config: Dict, **kwargs): + r""" + Instantiate a [`Owlv2Config`] (or a derived class) from owlv2 text model configuration and owlv2 vision + model configuration. + + Returns: + [`Owlv2Config`]: An instance of a configuration object + """ + config_dict = {} + config_dict["text_config"] = text_config + config_dict["vision_config"] = vision_config + + return cls.from_dict(config_dict, **kwargs) + + +__all__ = ["Owlv2Config", "Owlv2TextConfig", "Owlv2VisionConfig"] diff --git a/docs/transformers/src/transformers/models/owlv2/convert_owlv2_to_hf.py b/docs/transformers/src/transformers/models/owlv2/convert_owlv2_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..69665bab1d51a60a9cc2fc50cbf20621af3f87ff --- /dev/null +++ b/docs/transformers/src/transformers/models/owlv2/convert_owlv2_to_hf.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+"""Convert OWLv2 checkpoints from the original repository. + +URL: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit""" + +import argparse +import collections +import os + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from flax.training import checkpoints +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + CLIPTokenizer, + Owlv2Config, + Owlv2ForObjectDetection, + Owlv2ImageProcessor, + Owlv2Processor, + Owlv2TextConfig, + Owlv2VisionConfig, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_owlv2_config(model_name): + if "large" in model_name: + image_size = 1008 + patch_size = 14 + vision_hidden_size = 1024 + vision_intermediate_size = 4096 + vision_num_hidden_layers = 24 + vision_num_attention_heads = 16 + projection_dim = 768 + text_hidden_size = 768 + text_intermediate_size = 3072 + text_num_attention_heads = 12 + text_num_hidden_layers = 12 + else: + image_size = 960 + patch_size = 16 + vision_hidden_size = 768 + vision_intermediate_size = 3072 + vision_num_hidden_layers = 12 + vision_num_attention_heads = 12 + projection_dim = 512 + text_hidden_size = 512 + text_intermediate_size = 2048 + text_num_attention_heads = 8 + text_num_hidden_layers = 12 + + vision_config = Owlv2VisionConfig( + patch_size=patch_size, + image_size=image_size, + hidden_size=vision_hidden_size, + num_hidden_layers=vision_num_hidden_layers, + intermediate_size=vision_intermediate_size, + num_attention_heads=vision_num_attention_heads, + ) + text_config = Owlv2TextConfig( + hidden_size=text_hidden_size, + intermediate_size=text_intermediate_size, + num_attention_heads=text_num_attention_heads, + num_hidden_layers=text_num_hidden_layers, + ) + + config = Owlv2Config( + text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + projection_dim=projection_dim, + ) + + return config + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.MutableMapping): + items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, model_name): + rename_keys = [] + + # fmt: off + # CLIP vision encoder + rename_keys.append(("backbone/clip/visual/class_embedding", "owlv2.vision_model.embeddings.class_embedding")) + rename_keys.append(("backbone/clip/visual/conv1/kernel", "owlv2.vision_model.embeddings.patch_embedding.weight")) + rename_keys.append(("backbone/clip/visual/positional_embedding", "owlv2.vision_model.embeddings.position_embedding.weight")) + rename_keys.append(("backbone/clip/visual/ln_pre/scale", "owlv2.vision_model.pre_layernorm.weight")) + rename_keys.append(("backbone/clip/visual/ln_pre/bias", "owlv2.vision_model.pre_layernorm.bias")) + + for i in range(config.vision_config.num_hidden_layers): + if "v2" in model_name: + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_0/scale", f"owlv2.vision_model.encoder.layers.{i}.layer_norm1.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_0/bias", f"owlv2.vision_model.encoder.layers.{i}.layer_norm1.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_1/scale", 
f"owlv2.vision_model.encoder.layers.{i}.layer_norm2.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_1/bias", f"owlv2.vision_model.encoder.layers.{i}.layer_norm2.bias")) + else: + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_1/scale", f"owlv2.vision_model.encoder.layers.{i}.layer_norm1.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_1/bias", f"owlv2.vision_model.encoder.layers.{i}.layer_norm1.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_2/scale", f"owlv2.vision_model.encoder.layers.{i}.layer_norm2.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/ln_2/bias", f"owlv2.vision_model.encoder.layers.{i}.layer_norm2.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/mlp/c_fc/kernel", f"owlv2.vision_model.encoder.layers.{i}.mlp.fc1.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/mlp/c_fc/bias", f"owlv2.vision_model.encoder.layers.{i}.mlp.fc1.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/mlp/c_proj/kernel", f"owlv2.vision_model.encoder.layers.{i}.mlp.fc2.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/mlp/c_proj/bias", f"owlv2.vision_model.encoder.layers.{i}.mlp.fc2.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/query/kernel", f"owlv2.vision_model.encoder.layers.{i}.self_attn.q_proj.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/query/bias", f"owlv2.vision_model.encoder.layers.{i}.self_attn.q_proj.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/key/kernel", f"owlv2.vision_model.encoder.layers.{i}.self_attn.k_proj.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/key/bias", f"owlv2.vision_model.encoder.layers.{i}.self_attn.k_proj.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/value/kernel", f"owlv2.vision_model.encoder.layers.{i}.self_attn.v_proj.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/value/bias", f"owlv2.vision_model.encoder.layers.{i}.self_attn.v_proj.bias")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/out/kernel", f"owlv2.vision_model.encoder.layers.{i}.self_attn.out_proj.weight")) + rename_keys.append((f"backbone/clip/visual/transformer/resblocks.{i}/attn/out/bias", f"owlv2.vision_model.encoder.layers.{i}.self_attn.out_proj.bias")) + + rename_keys.append(("backbone/clip/visual/ln_post/scale", "owlv2.vision_model.post_layernorm.weight")) + rename_keys.append(("backbone/clip/visual/ln_post/bias", "owlv2.vision_model.post_layernorm.bias")) + + # CLIP text encoder + rename_keys.append(("backbone/clip/text/token_embedding/embedding", "owlv2.text_model.embeddings.token_embedding.weight")) + rename_keys.append(("backbone/clip/text/positional_embedding", "owlv2.text_model.embeddings.position_embedding.weight")) + + for i in range(config.text_config.num_hidden_layers): + if "v2" in model_name: + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_0/scale", f"owlv2.text_model.encoder.layers.{i}.layer_norm1.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_0/bias", f"owlv2.text_model.encoder.layers.{i}.layer_norm1.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_1/scale", 
f"owlv2.text_model.encoder.layers.{i}.layer_norm2.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_1/bias", f"owlv2.text_model.encoder.layers.{i}.layer_norm2.bias")) + else: + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_1/scale", f"owlv2.text_model.encoder.layers.{i}.layer_norm1.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_1/bias", f"owlv2.text_model.encoder.layers.{i}.layer_norm1.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_2/scale", f"owlv2.text_model.encoder.layers.{i}.layer_norm2.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/ln_2/bias", f"owlv2.text_model.encoder.layers.{i}.layer_norm2.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/mlp/c_fc/kernel", f"owlv2.text_model.encoder.layers.{i}.mlp.fc1.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/mlp/c_fc/bias", f"owlv2.text_model.encoder.layers.{i}.mlp.fc1.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/mlp/c_proj/kernel", f"owlv2.text_model.encoder.layers.{i}.mlp.fc2.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/mlp/c_proj/bias", f"owlv2.text_model.encoder.layers.{i}.mlp.fc2.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/query/kernel", f"owlv2.text_model.encoder.layers.{i}.self_attn.q_proj.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/query/bias", f"owlv2.text_model.encoder.layers.{i}.self_attn.q_proj.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/key/kernel", f"owlv2.text_model.encoder.layers.{i}.self_attn.k_proj.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/key/bias", f"owlv2.text_model.encoder.layers.{i}.self_attn.k_proj.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/value/kernel", f"owlv2.text_model.encoder.layers.{i}.self_attn.v_proj.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/value/bias", f"owlv2.text_model.encoder.layers.{i}.self_attn.v_proj.bias")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/out/kernel", f"owlv2.text_model.encoder.layers.{i}.self_attn.out_proj.weight")) + rename_keys.append((f"backbone/clip/text/transformer/resblocks.{i}/attn/out/bias", f"owlv2.text_model.encoder.layers.{i}.self_attn.out_proj.bias")) + + rename_keys.append(("backbone/clip/text/ln_final/scale", "owlv2.text_model.final_layer_norm.weight")) + rename_keys.append(("backbone/clip/text/ln_final/bias", "owlv2.text_model.final_layer_norm.bias")) + + # logit scale + rename_keys.append(("backbone/clip/logit_scale", "owlv2.logit_scale")) + + # projection heads + rename_keys.append(("backbone/clip/text/text_projection/kernel", "owlv2.text_projection.weight")) + + # class and box heads + rename_keys.append(("backbone/merged_class_token/scale", "layer_norm.weight")) + rename_keys.append(("backbone/merged_class_token/bias", "layer_norm.bias")) + rename_keys.append(("class_head/Dense_0/kernel", "class_head.dense0.weight")) + rename_keys.append(("class_head/Dense_0/bias", "class_head.dense0.bias")) + rename_keys.append(("class_head/logit_shift/kernel", "class_head.logit_shift.weight")) + rename_keys.append(("class_head/logit_scale/kernel", "class_head.logit_scale.weight")) + rename_keys.append(("class_head/logit_scale/bias", 
"class_head.logit_scale.bias")) + rename_keys.append(("class_head/logit_shift/bias", "class_head.logit_shift.bias")) + rename_keys.append(("obj_box_head/Dense_0/kernel", "box_head.dense0.weight")) + rename_keys.append(("obj_box_head/Dense_0/bias", "box_head.dense0.bias")) + rename_keys.append(("obj_box_head/Dense_1/kernel", "box_head.dense1.weight")) + rename_keys.append(("obj_box_head/Dense_1/bias", "box_head.dense1.bias")) + rename_keys.append(("obj_box_head/Dense_2/kernel", "box_head.dense2.weight")) + rename_keys.append(("obj_box_head/Dense_2/bias", "box_head.dense2.bias")) + + # objectness head (only for v2) + if "v2" in model_name: + rename_keys.append(("objectness_head/Dense_0/kernel", "objectness_head.dense0.weight")) + rename_keys.append(("objectness_head/Dense_0/bias", "objectness_head.dense0.bias")) + rename_keys.append(("objectness_head/Dense_1/kernel", "objectness_head.dense1.weight")) + rename_keys.append(("objectness_head/Dense_1/bias", "objectness_head.dense1.bias")) + rename_keys.append(("objectness_head/Dense_2/kernel", "objectness_head.dense2.weight")) + rename_keys.append(("objectness_head/Dense_2/bias", "objectness_head.dense2.bias")) + + # fmt: on + + return rename_keys + + +def rename_and_reshape_key(dct, old, new, config): + val = dct.pop(old) + + if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new: + val = val.reshape(-1, config.vision_config.hidden_size) + if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new: + val = val.reshape(-1, config.text_config.hidden_size) + + if "patch_embedding" in new: + print("Reshaping patch embedding... for", new) + val = val.transpose(3, 2, 0, 1) + elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new: + val = val.T + + if new.endswith("bias"): + val = val.reshape(-1) + + dct[new] = torch.from_numpy(np.array(val)) + + +@torch.no_grad() +def convert_owlv2_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub, verify_logits): + """ + Copy/paste/tweak model's weights to our OWL-ViT structure. 
+ """ + config = get_owlv2_config(model_name) + + # see available checkpoints at https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit#pretrained-checkpoints + variables = checkpoints.restore_checkpoint(checkpoint_path, target=None) + variables = variables["params"] if "v2" in model_name else variables["optimizer"]["target"] + flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables) + state_dict = flatten_nested_dict(flax_params) + + # Rename keys + rename_keys = create_rename_keys(config, model_name) + for src, dest in rename_keys: + rename_and_reshape_key(state_dict, src, dest, config) + + # load HuggingFace model + model = Owlv2ForObjectDetection(config) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + assert missing_keys == ["owlv2.visual_projection.weight"] + assert unexpected_keys == [] + model.eval() + + # Initialize image processor + size = {"height": config.vision_config.image_size, "width": config.vision_config.image_size} + image_processor = Owlv2ImageProcessor(size=size) + # Initialize tokenizer + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16) + # Initialize processor + processor = Owlv2Processor(image_processor=image_processor, tokenizer=tokenizer) + + # Verify pixel_values and input_ids + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="owlvit_pixel_values_960.pt", repo_type="dataset") + original_pixel_values = torch.load(filepath, weights_only=True).permute(0, 3, 1, 2) + + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="owlv2_input_ids.pt", repo_type="dataset") + original_input_ids = torch.load(filepath, weights_only=True).squeeze() + + filepath = hf_hub_download(repo_id="adirik/OWL-ViT", repo_type="space", filename="assets/astronaut.png") + image = Image.open(filepath) + texts = [["face", "rocket", "nasa badge", "star-spangled banner"]] + inputs = processor(text=texts, images=image, return_tensors="pt") + + if "large" not in model_name: + assert torch.allclose(inputs.pixel_values, original_pixel_values.float(), atol=1e-6) + assert torch.allclose(inputs.input_ids[:4, :], original_input_ids[:4, :], atol=1e-6) + + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits + pred_boxes = outputs.pred_boxes + objectness_logits = outputs.objectness_logits + + if verify_logits: + if model_name == "owlv2-base-patch16": + expected_logits = torch.tensor( + [[-10.0043, -9.0226, -8.0433], [-12.4569, -14.0380, -12.6153], [-21.0731, -22.2705, -21.8850]] + ) + expected_boxes = torch.tensor( + [[0.0136, 0.0223, 0.0269], [0.0406, 0.0327, 0.0797], [0.0638, 0.1539, 0.1255]] + ) + expected_objectness_logits = torch.tensor( + [[-5.6589, -7.7702, -16.3965]], + ) + elif model_name == "owlv2-base-patch16-finetuned": + expected_logits = torch.tensor( + [[-9.2391, -9.2313, -8.0295], [-14.5498, -16.8450, -14.7166], [-15.1278, -17.3060, -15.7169]], + ) + expected_boxes = torch.tensor( + [[0.0103, 0.0094, 0.0207], [0.0483, 0.0729, 0.1013], [0.0629, 0.1396, 0.1313]] + ) + expected_objectness_logits = torch.tensor( + [[-6.5234, -13.3788, -14.6627]], + ) + elif model_name == "owlv2-base-patch16-ensemble": + expected_logits = torch.tensor( + [[-8.6353, -9.5409, -6.6154], [-7.9442, -9.6151, -6.7117], [-12.4593, -15.3332, -12.1048]] + ) + expected_boxes = torch.tensor( + [[0.0126, 0.0090, 0.0238], [0.0387, 0.0227, 0.0754], [0.0582, 0.1058, 0.1139]] + ) + expected_objectness_logits = 
torch.tensor( + [[-6.0628, -5.9507, -10.4486]], + ) + elif model_name == "owlv2-large-patch14": + expected_logits = torch.tensor( + [[-12.6662, -11.8384, -12.1880], [-16.0599, -16.5835, -16.9364], [-21.4957, -26.7038, -25.1313]], + ) + expected_boxes = torch.tensor( + [[0.0136, 0.0161, 0.0256], [0.0126, 0.0135, 0.0202], [0.0498, 0.0948, 0.0915]], + ) + expected_objectness_logits = torch.tensor( + [[-6.7196, -9.4590, -13.9472]], + ) + elif model_name == "owlv2-large-patch14-finetuned": + expected_logits = torch.tensor( + [[-9.5413, -9.7130, -7.9762], [-9.5731, -9.7277, -8.2252], [-15.4434, -19.3084, -16.5490]], + ) + expected_boxes = torch.tensor( + [[0.0089, 0.0080, 0.0175], [0.0112, 0.0098, 0.0179], [0.0375, 0.0821, 0.0528]], + ) + expected_objectness_logits = torch.tensor( + [[-6.2655, -6.5845, -11.3105]], + ) + elif model_name == "owlv2-large-patch14-ensemble": + expected_logits = torch.tensor( + [[-12.2037, -12.2070, -11.5371], [-13.4875, -13.8235, -13.1586], [-18.2007, -22.9834, -20.6816]], + ) + expected_boxes = torch.tensor( + [[0.0126, 0.0127, 0.0222], [0.0107, 0.0113, 0.0164], [0.0482, 0.1162, 0.0885]], + ) + expected_objectness_logits = torch.tensor( + [[-7.7572, -8.3637, -13.0334]], + ) + + print("Objectness logits:", objectness_logits[:3, :3]) + print("Logits:", logits[0, :3, :3]) + print("Pred boxes:", pred_boxes[0, :3, :3]) + + assert torch.allclose(logits[0, :3, :3], expected_logits, atol=1e-3) + assert torch.allclose(pred_boxes[0, :3, :3], expected_boxes, atol=1e-3) + assert torch.allclose(objectness_logits[:3, :3], expected_objectness_logits, atol=1e-3) + print("Looks ok!") + else: + print("Model converted without verifying logits") + + if pytorch_dump_folder_path is not None: + print("Saving model and processor locally...") + # Create folder to save model + if not os.path.isdir(pytorch_dump_folder_path): + os.mkdir(pytorch_dump_folder_path) + + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing {model_name} to the hub...") + model.push_to_hub(f"google/{model_name}") + processor.push_to_hub(f"google/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_name", + default="owlv2-base-patch16", + choices=[ + "owlv2-base-patch16", + "owlv2-base-patch16-finetuned", + "owlv2-base-patch16-ensemble", + "owlv2-large-patch14", + "owlv2-large-patch14-finetuned", + "owlv2-large-patch14-ensemble", + ], + type=str, + help="Name of the Owlv2 model you'd like to convert from FLAX to PyTorch.", + ) + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the original Flax checkpoint.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=False, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--verify_logits", + action="store_false", + required=False, + help="Whether to verify the logits of the converted model against the original ones. Verification is enabled by default; passing this flag disables it.", + ) + parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub") + + args = parser.parse_args() + convert_owlv2_checkpoint( + args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits + ) diff --git a/docs/transformers/src/transformers/models/owlv2/image_processing_owlv2.py b/docs/transformers/src/transformers/models/owlv2/image_processing_owlv2.py new file mode 100644 index
0000000000000000000000000000000000000000..56c8075ef0ee085dd81412284d9bb50d65f2b15b --- /dev/null +++ b/docs/transformers/src/transformers/models/owlv2/image_processing_owlv2.py @@ -0,0 +1,635 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for OWLv2.""" + +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + center_to_corners_format, + pad, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_scipy_available, + is_torch_available, + is_vision_available, + logging, + requires_backends, +) + + +if is_torch_available(): + import torch + + +if is_vision_available(): + import PIL + +if is_scipy_available(): + from scipy import ndimage as ndi + +if TYPE_CHECKING: + from .modeling_owlv2 import Owlv2ObjectDetectionOutput + +logger = logging.get_logger(__name__) + + +def _scale_boxes(boxes, target_sizes): + """ + Scale batch of bounding boxes to the target sizes. + + Args: + boxes (`torch.Tensor` of shape `(batch_size, num_boxes, 4)`): + Bounding boxes to scale. Each box is expected to be in (x1, y1, x2, y2) format. + target_sizes (`List[Tuple[int, int]]` or `torch.Tensor` of shape `(batch_size, 2)`): + Target sizes to scale the boxes to. Each target size is expected to be in (height, width) format. + + Returns: + `torch.Tensor` of shape `(batch_size, num_boxes, 4)`: Scaled bounding boxes. 
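Example (a sketch of the padded-square scaling convention described above, with made-up numbers):

```python
>>> import torch

>>> boxes = torch.tensor([[[0.1, 0.1, 0.2, 0.2]]])  # normalized (x1, y1, x2, y2)
>>> # For a 480x640 image the longer side is 640, so every coordinate is scaled by 640:
>>> boxes * 640
tensor([[[ 64.,  64., 128., 128.]]])
```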
+ """ + + if isinstance(target_sizes, (list, tuple)): + image_height = torch.tensor([i[0] for i in target_sizes]) + image_width = torch.tensor([i[1] for i in target_sizes]) + elif isinstance(target_sizes, torch.Tensor): + image_height, image_width = target_sizes.unbind(1) + else: + raise ValueError("`target_sizes` must be a list, tuple or torch.Tensor") + + # for owlv2 image is padded to max size unlike owlvit, thats why we have to scale boxes to max size + max_size = torch.max(image_height, image_width) + + scale_factor = torch.stack([max_size, max_size, max_size, max_size], dim=1) + scale_factor = scale_factor.unsqueeze(1).to(boxes.device) + boxes = boxes * scale_factor + return boxes + + +# Copied from transformers.models.owlvit.image_processing_owlvit._upcast +def _upcast(t): + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +# Copied from transformers.models.owlvit.image_processing_owlvit.box_area +def box_area(boxes): + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# Copied from transformers.models.owlvit.image_processing_owlvit.box_iou +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def _preprocess_resize_output_shape(image, output_shape): + """Validate resize output shape according to input image. + + Args: + image (`np.ndarray`): + Image to be resized. + output_shape (`iterable`): + Size of the generated output image `(rows, cols[, ...][, dim])`. If `dim` is not provided, the number of + channels is preserved. + + Returns + image (`np.ndarray`): + The input image, but with additional singleton dimensions appended in the case where `len(output_shape) > + input.ndim`. + output_shape (`Tuple`): + The output shape converted to tuple. + + Raises ------ ValueError: + If output_shape length is smaller than the image number of dimensions. + + Notes ----- The input image is reshaped if its number of dimensions is not equal to output_shape_length. 
+ + """ + output_shape = tuple(output_shape) + output_ndim = len(output_shape) + input_shape = image.shape + if output_ndim > image.ndim: + # append dimensions to input_shape + input_shape += (1,) * (output_ndim - image.ndim) + image = np.reshape(image, input_shape) + elif output_ndim == image.ndim - 1: + # multichannel case: append shape of last axis + output_shape = output_shape + (image.shape[-1],) + elif output_ndim < image.ndim: + raise ValueError("output_shape length cannot be smaller than the image number of dimensions") + + return image, output_shape + + +def _clip_warp_output(input_image, output_image): + """Clip output image to range of values of input image. + + Note that this function modifies the values of *output_image* in-place. + + Taken from: + https://github.com/scikit-image/scikit-image/blob/b4b521d6f0a105aabeaa31699949f78453ca3511/skimage/transform/_warps.py#L640. + + Args: + input_image : ndarray + Input image. + output_image : ndarray + Output image, which is modified in-place. + """ + min_val = np.min(input_image) + if np.isnan(min_val): + # NaNs detected, use NaN-safe min/max + min_func = np.nanmin + max_func = np.nanmax + min_val = min_func(input_image) + else: + min_func = np.min + max_func = np.max + max_val = max_func(input_image) + + output_image = np.clip(output_image, min_val, max_val) + + return output_image + + +class Owlv2ImageProcessor(BaseImageProcessor): + r""" + Constructs an OWLv2 image processor. + + Args: + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess` + method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to a square with gray pixels on the bottom and the right. Can be overriden by + `do_pad` in the `preprocess` method. + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden + by `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 960, "width": 960}`): + Size to resize the image to. Can be overriden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling method to use if resizing the image. Can be overriden by `resample` in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_pad: bool = True, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_pad = do_pad + self.do_resize = do_resize + self.size = size if size is not None else {"height": 960, "width": 960} + self.resample = resample + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + + def pad( + self, + image: np.array, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pad an image to a square with gray pixels on the bottom and the right, as per the original OWLv2 + implementation. + + Args: + image (`np.ndarray`): + Image to pad. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input + image. + """ + height, width = get_image_size(image) + size = max(height, width) + image = pad( + image=image, + padding=((0, size - height), (0, size - width)), + constant_values=0.5, + data_format=data_format, + input_data_format=input_data_format, + ) + + return image + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + anti_aliasing: bool = True, + anti_aliasing_sigma=None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image as per the original implementation. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary containing the height and width to resize the image to. + anti_aliasing (`bool`, *optional*, defaults to `True`): + Whether to apply anti-aliasing when downsampling the image. + anti_aliasing_sigma (`float`, *optional*, defaults to `None`): + Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated + automatically. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input + image. 
+ """ + requires_backends(self, "scipy") + + output_shape = (size["height"], size["width"]) + image = to_channel_dimension_format(image, ChannelDimension.LAST) + image, output_shape = _preprocess_resize_output_shape(image, output_shape) + input_shape = image.shape + factors = np.divide(input_shape, output_shape) + + # Translate modes used by np.pad to those used by scipy.ndimage + ndi_mode = "mirror" + cval = 0 + order = 1 + if anti_aliasing: + if anti_aliasing_sigma is None: + anti_aliasing_sigma = np.maximum(0, (factors - 1) / 2) + else: + anti_aliasing_sigma = np.atleast_1d(anti_aliasing_sigma) * np.ones_like(factors) + if np.any(anti_aliasing_sigma < 0): + raise ValueError("Anti-aliasing standard deviation must be greater than or equal to zero") + elif np.any((anti_aliasing_sigma > 0) & (factors <= 1)): + warnings.warn( + "Anti-aliasing standard deviation greater than zero but not down-sampling along all axes" + ) + filtered = ndi.gaussian_filter(image, anti_aliasing_sigma, cval=cval, mode=ndi_mode) + else: + filtered = image + + zoom_factors = [1 / f for f in factors] + out = ndi.zoom(filtered, zoom_factors, order=order, mode=ndi_mode, cval=cval, grid_mode=True) + + image = _clip_warp_output(image, out) + + image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST) + image = ( + to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image + ) + return image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_pad: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to a square with gray pixels on the bottom and the right. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size to resize the image to. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_pad = do_pad if do_pad is not None else self.do_pad + do_resize = do_resize if do_resize is not None else self.do_resize + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size) # for BC + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + # Here, pad and resize methods are different from the rest of image processors + # as they don't have any resampling in resize() + # or pad size in pad() (the maximum of (height, width) is taken instead). + # hence, these arguments don't need to be passed in validate_preprocess_arguments. + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + size=size, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
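+            # e.g. an array of shape (height, width, 3) is inferred as ChannelDimension.LAST and one of shape
+            # (3, height, width) as ChannelDimension.FIRST; pass `input_data_format` explicitly for ambiguous shapes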
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_pad: + images = [self.pad(image=image, input_data_format=input_data_format) for image in images] + + if do_resize: + images = [ + self.resize( + image=image, + size=size, + input_data_format=input_data_format, + ) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection with OwlViT->Owlv2 + def post_process_object_detection( + self, + outputs: "Owlv2ObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + ): + """ + Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`Owlv2ObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + """ + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + batch_size = len(batch_logits) + + if target_sizes is not None and len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as images") + + # batch_logits of shape (batch_size, num_queries, num_classes) + batch_class_logits = torch.max(batch_logits, dim=-1) + batch_scores = torch.sigmoid(batch_class_logits.values) + batch_labels = batch_class_logits.indices + + # Convert to [x0, y0, x1, y1] format + batch_boxes = center_to_corners_format(batch_boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + batch_boxes = _scale_boxes(batch_boxes, target_sizes) + + results = [] + for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes): + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + results.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection + def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. 
+ + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. All labels are set to None as + `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. + """ + logits, target_boxes = outputs.logits, outputs.target_pred_boxes + + if target_sizes is not None and len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + + # Convert to [x0, y0, x1, y1] format + target_boxes = center_to_corners_format(target_boxes) + + # Apply non-maximum suppression (NMS) + if nms_threshold < 1.0: + for idx in range(target_boxes.shape[0]): + for i in torch.argsort(-scores[idx]): + if not scores[idx][i]: + continue + + ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] + ious[i] = -1.0 # Mask self-IoU. + scores[idx][ious > nms_threshold] = 0.0 + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + target_boxes = _scale_boxes(target_boxes, target_sizes) + + # Compute box display alphas based on prediction scores + results = [] + alphas = torch.zeros_like(scores) + + for idx in range(target_boxes.shape[0]): + # Select scores for boxes matching the current query: + query_scores = scores[idx] + if not query_scores.nonzero().numel(): + continue + + # Apply threshold on scores before scaling + query_scores[query_scores < threshold] = 0.0 + + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. + max_score = torch.max(query_scores) + 1e-6 + query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) + query_alphas = torch.clip(query_alphas, 0.0, 1.0) + alphas[idx] = query_alphas + + mask = alphas[idx] > 0 + box_scores = alphas[idx][mask] + boxes = target_boxes[idx][mask] + results.append({"scores": box_scores, "labels": None, "boxes": boxes}) + + return results + + +__all__ = ["Owlv2ImageProcessor"] diff --git a/docs/transformers/src/transformers/models/owlv2/modeling_owlv2.py b/docs/transformers/src/transformers/models/owlv2/modeling_owlv2.py new file mode 100644 index 0000000000000000000000000000000000000000..7d83ab0004d878e7488d636c0f0d55cb9ede6f79 --- /dev/null +++ b/docs/transformers/src/transformers/models/owlv2/modeling_owlv2.py @@ -0,0 +1,1845 @@ +# coding=utf-8 +# Copyright 2023 Google AI and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OWLv2 model.""" + +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_vision_available, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/owlv2-base-patch16-ensemble" + +# See all Owlv2 models at https://huggingface.co/models?filter=owlv2 + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlv2 +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlv2 +def owlv2_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class Owlv2Output(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`Owlv2VisionModel`]. + text_model_output (Tuple[`BaseModelOutputWithPooling`]): + The output of the [`Owlv2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Owlv2VisionModel`]. 
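+
+    Note that `logits_per_text` is simply the transpose of `logits_per_image`; both are derived from the same
+    pairwise similarity matrix between the text and image embeddings.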
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.loss.loss_for_object_detection._upcast +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +# Copied from transformers.loss.loss_for_object_detection.box_area +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# Copied from transformers.loss.loss_for_object_detection.box_iou +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +@dataclass +class Owlv2ObjectDetectionOutput(ModelOutput): + """ + Output type of [`Owlv2ForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. 
+ loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + objectness_logits (`torch.FloatTensor` of shape `(batch_size, num_patches, 1)`): + The objectness logits of all image patches. OWL-ViT represents images as a set of image patches where the + total number of patches is (image_size / patch_size)**2. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes image + embeddings for each patch. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_output (Tuple[`BaseModelOutputWithPooling`]): + The output of the [`Owlv2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Owlv2VisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: Optional[torch.FloatTensor] = None + objectness_logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + class_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTImageGuidedObjectDetectionOutput with OwlViT->Owlv2,OWL-ViT->OWLv2 +class Owlv2ImageGuidedObjectDetectionOutput(ModelOutput): + """ + Output type of [`Owlv2ForObjectDetection.image_guided_detection`]. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual target image in the batch + (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. + query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). 
These + values are normalized in [0, 1], relative to the size of each individual query image in the batch + (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to + retrieve the unnormalized bounding boxes. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes + image embeddings for each patch. + query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes + image embeddings for each patch. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_output (Tuple[`BaseModelOutputWithPooling`]): + The output of the [`Owlv2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Owlv2VisionModel`]. + """ + + logits: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + query_image_embeds: Optional[torch.FloatTensor] = None + target_pred_boxes: Optional[torch.FloatTensor] = None + query_pred_boxes: Optional[torch.FloatTensor] = None + class_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionEmbeddings with OwlViT->Owlv2 +class Owlv2VisionEmbeddings(nn.Module): + def __init__(self, config: Owlv2VisionConfig): + super().__init__() + self.patch_size = config.patch_size + self.config = config + self.embed_dim = config.hidden_size + self.class_embedding = nn.Parameter(torch.randn(config.hidden_size)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + + self.num_patches = (config.image_size // config.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
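+
+        For example, a checkpoint pre-trained at 960x960 resolution with patch size 16 stores a 60x60 grid of patch
+        position embeddings; for 1008x1008 inputs, that grid is bicubically interpolated to a 63x63 grid before
+        being added to the patch embeddings, while the class token embedding is reused unchanged.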
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextEmbeddings with OwlViT->Owlv2 +class Owlv2TextEmbeddings(nn.Module): + def __init__(self, config: Owlv2TextConfig): + super().__init__() + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTAttention with OwlViT->Owlv2 +class Owlv2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + 
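+        # e.g. the base patch16 vision config has hidden_size=768 and num_attention_heads=12, giving head_dim=64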
self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # For int8 compatibility, sometimes the `attn_probs` are in `fp32` + attn_probs = attn_probs.to(value_states.dtype) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Owlv2 +class Owlv2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Owlv2 +class Owlv2EncoderLayer(nn.Module): + def __init__(self, config: Owlv2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Owlv2Attention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Owlv2MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
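+        Returns:
+            `Tuple[torch.FloatTensor]`: the transformed hidden states of shape `(batch, seq_len, embed_dim)`,
+            followed by the attention weights when `output_attentions=True`.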
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTPreTrainedModel with OwlViT->Owlv2,owlvit->owlv2 +class Owlv2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Owlv2Config + base_model_prefix = "owlv2" + supports_gradient_checkpointing = True + _no_split_modules = ["Owlv2EncoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, Owlv2TextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, Owlv2VisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, Owlv2Attention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, Owlv2MLP): + factor = self.config.initializer_factor + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, Owlv2Model): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +OWLV2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
+
+    Parameters:
+        config ([`Owlv2Config`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+OWLV2_TEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLV2_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLV2_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. + return_base_image_embeds (`bool`, *optional*): + Whether or not to return the base image embeddings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +OWLV2_OBJECT_DETECTION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids). + attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_hidden_states (`bool`, *optional*): + Whether or not to return the last hidden state. See `text_model_last_hidden_state` and + `vision_model_last_hidden_state` under returned tensors for more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +OWLV2_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values of query image(s) to be detected. Pass in one query image per target image. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTEncoder with OwlViT->Owlv2 +class Owlv2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Owlv2EncoderLayer`]. + + Args: + config: Owlv2Config + """ + + def __init__(self, config: Owlv2Config): + super().__init__() + self.layers = nn.ModuleList([Owlv2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`). 
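+                Embedded representation of the inputs, as produced by the text or vision embedding layer.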
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextTransformer with OWLVIT->OWLV2,OwlViT->Owlv2 +class Owlv2TextTransformer(nn.Module): + def __init__(self, config: Owlv2TextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = Owlv2TextEmbeddings(config) + self.encoder = Owlv2Encoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(OWLV2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Owlv2TextConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + 
output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries + # OWLV2's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # take features from the end of tokens embedding (end of token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextModel with google/owlvit-base-patch32->google/owlv2-base-patch16, OWLVIT->OWLV2,OwlViT->Owlv2 +class Owlv2TextModel(Owlv2PreTrainedModel): + config_class = Owlv2TextConfig + + def __init__(self, config: Owlv2TextConfig): + super().__init__(config) + self.text_model = Owlv2TextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(OWLV2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Owlv2TextConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + ```python + >>> from transformers import AutoProcessor, Owlv2TextModel + + >>> model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16") + >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16") + >>> inputs = processor( + ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt" + ... 
) + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + # Get embeddings for all text queries in all batch samples + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionTransformer with OWLVIT->OWLV2,OwlViT->Owlv2 +class Owlv2VisionTransformer(nn.Module): + def __init__(self, config: Owlv2VisionConfig): + super().__init__() + self.config = config + + self.embeddings = Owlv2VisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.encoder = Owlv2Encoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(OWLV2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Owlv2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Cast the input to the expected `dtype` + expected_input_dtype = self.embeddings.patch_embedding.weight.dtype + pixel_values = pixel_values.to(expected_input_dtype) + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layernorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTVisionModel with OWLVIT->OWLV2,OwlViT->Owlv2,google/owlvit-base-patch32->google/owlv2-base-patch16 +class Owlv2VisionModel(Owlv2PreTrainedModel): + config_class = Owlv2VisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Owlv2VisionConfig): + super().__init__(config) + self.vision_model = Owlv2VisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(OWLV2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Owlv2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: 
Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        Examples:
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Owlv2VisionModel
+
+        >>> model = Owlv2VisionModel.from_pretrained("google/owlv2-base-patch16")
+        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled CLS states
+        ```"""
+        return self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+
+@add_start_docstrings(OWLV2_START_DOCSTRING)
+# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTModel with google/owlvit-base-patch32->google/owlv2-base-patch16-ensemble, OWLVIT->OWLV2,OwlViT->Owlv2,owlvit->owlv2,OWL-ViT->OWLv2
+class Owlv2Model(Owlv2PreTrainedModel):
+    config_class = Owlv2Config
+
+    def __init__(self, config: Owlv2Config):
+        super().__init__(config)
+
+        if not isinstance(config.text_config, Owlv2TextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type Owlv2TextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, Owlv2VisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type Owlv2VisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+        self.vision_embed_dim = vision_config.hidden_size
+
+        self.text_model = Owlv2TextTransformer(text_config)
+        self.vision_model = Owlv2VisionTransformer(vision_config)
+
+        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+        self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(OWLV2_TEXT_INPUTS_DOCSTRING)
+    def get_text_features(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`Owlv2TextModel`].
+
+        Examples:
+        ```python
+        >>> from transformers import AutoProcessor, Owlv2Model
+
+        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
+        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+        >>> inputs = processor(
+        ...     text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
+        ...
) + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use OWLv2 model's config for some fields (if specified) instead of those of vision & text components. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get embeddings for all text queries in all batch samples + text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=return_dict) + pooled_output = text_output[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(OWLV2_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`Owlv2VisionModel`]. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Owlv2Model + + >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble") + >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use OWLv2 model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(OWLV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Owlv2Output, config_class=Owlv2Config) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_base_image_embeds: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Owlv2Output]: + r""" + Returns: + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Owlv2Model + + >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble") + >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(text=[["a photo 
of a cat", "a photo of a dog"]], images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use OWLv2 model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + # Get embeddings for all text queries in all batch samples + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + # normalized features + image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True) + text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True) + + # cosine similarity as logits and set it on the correct device + logit_scale = self.logit_scale.exp().to(image_embeds.device) + + logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = owlv2_loss(logits_per_text) + + text_embeds = text_embeds_norm + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return Owlv2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTBoxPredictionHead with OwlViT->Owlv2 +class Owlv2BoxPredictionHead(nn.Module): + def __init__(self, config: Owlv2Config, out_dim: int = 4): + super().__init__() + + width = config.vision_config.hidden_size + self.dense0 = nn.Linear(width, width) + self.dense1 = nn.Linear(width, width) + self.gelu = nn.GELU() + self.dense2 = nn.Linear(width, out_dim) + + def forward(self, image_features: torch.Tensor) -> torch.FloatTensor: + output = self.dense0(image_features) + output = self.gelu(output) + output = self.dense1(output) + output = self.gelu(output) + output = self.dense2(output) + return output + + +# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTClassPredictionHead with OwlViT->Owlv2 +class Owlv2ClassPredictionHead(nn.Module): + def __init__(self, config: Owlv2Config): + super().__init__() + + out_dim = config.text_config.hidden_size + self.query_dim = config.vision_config.hidden_size + + self.dense0 = nn.Linear(self.query_dim, out_dim) + self.logit_shift = nn.Linear(self.query_dim, 1) + self.logit_scale = nn.Linear(self.query_dim, 1) + 
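+        # Note: `dense0` projects patch features from the vision width (`query_dim`)
+        # into the text embedding space (`out_dim`), while `logit_shift` and
+        # `logit_scale` learn a per-patch affine correction that `forward` applies
+        # to the raw image-text similarity logits (the scale is passed through
+        # ELU + 1 to keep it positive).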
self.elu = nn.ELU() + + def forward( + self, + image_embeds: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor], + query_mask: Optional[torch.Tensor], + ) -> Tuple[torch.FloatTensor]: + image_class_embeds = self.dense0(image_embeds) + if query_embeds is None: + device = image_class_embeds.device + batch_size, num_patches = image_class_embeds.shape[:2] + pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device) + return (pred_logits, image_class_embeds) + + # Normalize image and text features + image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6) + query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6) + + # Get class predictions + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + + # Apply a learnable shift and scale to logits + logit_shift = self.logit_shift(image_embeds) + logit_scale = self.logit_scale(image_embeds) + logit_scale = self.elu(logit_scale) + 1 + pred_logits = (pred_logits + logit_shift) * logit_scale + + if query_mask is not None: + if query_mask.ndim > 1: + query_mask = torch.unsqueeze(query_mask, dim=-2) + + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) + pred_logits = pred_logits.to(torch.float32) + + return (pred_logits, image_class_embeds) + + +class Owlv2ForObjectDetection(Owlv2PreTrainedModel): + config_class = Owlv2Config + + def __init__(self, config: Owlv2Config): + super().__init__(config) + + self.owlv2 = Owlv2Model(config) + self.class_head = Owlv2ClassPredictionHead(config) + self.box_head = Owlv2BoxPredictionHead(config) + self.objectness_head = Owlv2BoxPredictionHead(config, out_dim=1) + + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) + self.sigmoid = nn.Sigmoid() + self.config = config + self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width) + + @staticmethod + # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates + def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor: + # Create grid coordinates using torch + x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32) + y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32) + xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy") + + # Stack the coordinates and divide by their respective patch counts + box_coordinates = torch.stack((xx, yy), dim=-1) + box_coordinates[..., 0] /= num_patches_width + box_coordinates[..., 1] /= num_patches_height + + # Flatten (h, w, 2) -> (h*w, 2) + box_coordinates = box_coordinates.view(-1, 2) + + return box_coordinates + + def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor: + """Predicts the probability that each image feature token is an object. + + Args: + image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)): + Features extracted from the image. + Returns: + Objectness scores. 
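+
+        Example (an illustrative sketch only; the shapes assume the base patch16
+        checkpoint, whose 960x960 input yields (960 / 16) ** 2 = 3600 patch tokens
+        of hidden size 768, and an already-instantiated `Owlv2ForObjectDetection`
+        named `model`):
+
+        ```python
+        >>> import torch
+
+        >>> image_features = torch.randn(1, 3600, 768)  # (batch_size, num_patches, hidden_dim)
+        >>> objectness_logits = model.objectness_predictor(image_features)
+        >>> objectness_logits.shape
+        torch.Size([1, 3600])
+        ```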
+ """ + image_features = image_features.detach() + objectness_logits = self.objectness_head(image_features) + objectness_logits = objectness_logits[..., 0] + return objectness_logits + + @lru_cache(maxsize=2) + # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias + def compute_box_bias( + self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None + ) -> torch.Tensor: + if feature_map is not None: + raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead") + # The box center is biased to its position on the feature grid + box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width) + box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) + + # Unnormalize xy + box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) + + # The box size is biased to the patch size + box_size = torch.full_like(box_coord_bias, 1.0) + box_size[..., 0] /= num_patches_width + box_size[..., 1] /= num_patches_height + box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) + + # Compute box bias + box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1) + return box_bias + + # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.box_predictor + def box_predictor( + self, + image_feats: torch.FloatTensor, + feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + """ + Args: + image_feats: + Features extracted from the image, returned by the `image_text_embedder` method. + feature_map: + A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + interpolate_pos_encoding: + Whether to interpolate the pre-trained position encodings. + Returns: + pred_boxes: + List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. + """ + # Bounding box detection head [batch_size, num_boxes, 4]. + pred_boxes = self.box_head(image_feats) + + # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction + if interpolate_pos_encoding: + _, num_patches_height, num_patches_width, _ = feature_map.shape + box_bias = self.compute_box_bias(num_patches_height, num_patches_width) + else: + box_bias = self.box_bias + + box_bias = box_bias.to(feature_map.device) + pred_boxes += box_bias + pred_boxes = self.sigmoid(pred_boxes) + return pred_boxes + + # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.class_predictor + def class_predictor( + self, + image_feats: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor] = None, + query_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + image_feats: + Features extracted from the `image_text_embedder`. + query_embeds: + Text query embeddings. + query_mask: + Must be provided with query_embeddings. A mask indicating which query embeddings are valid. 
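+
+        Example (an illustrative sketch only; shapes assume the base patch16
+        checkpoint -- 3600 patches, vision width 768, text width 512 -- and an
+        already-instantiated `Owlv2ForObjectDetection` named `model`):
+
+        ```python
+        >>> import torch
+
+        >>> image_feats = torch.randn(1, 3600, 768)  # (batch_size, num_patches, vision_hidden_size)
+        >>> query_embeds = torch.randn(1, 2, 512)  # (batch_size, num_queries, text_hidden_size)
+        >>> query_mask = torch.ones(1, 2, dtype=torch.bool)
+        >>> pred_logits, class_embeds = model.class_predictor(image_feats, query_embeds, query_mask)
+        >>> pred_logits.shape, class_embeds.shape
+        (torch.Size([1, 3600, 2]), torch.Size([1, 3600, 512]))
+        ```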
+ """ + (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask) + + return (pred_logits, image_class_embeds) + + # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.image_text_embedder with owlvit->owlv2 + def image_text_embedder( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.FloatTensor]: + # Encode text and image + outputs = self.owlv2( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Get image embeddings + last_hidden_state = outputs.vision_model_output[0] + image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + text_embeds = outputs[-4] + + return (text_embeds, image_embeds, outputs) + + # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.image_embedder with owlvit->owlv2, OwlViTModel->Owlv2Model + def image_embedder( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.FloatTensor]: + # Get Owlv2Model vision embeddings (same as CLIP) + vision_outputs = self.owlv2.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Apply post_layernorm to last_hidden_state, return non-projected output + last_hidden_state = vision_outputs[0] + image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + + return (image_embeds, vision_outputs) + + # 
Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.embed_image_query + def embed_image_query( + self, + query_image_features: torch.FloatTensor, + query_feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + _, class_embeds = self.class_predictor(query_image_features) + pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding) + pred_boxes_as_corners = center_to_corners_format(pred_boxes) + + # Loop over query images + best_class_embeds = [] + best_box_indices = [] + pred_boxes_device = pred_boxes_as_corners.device + + for i in range(query_image_features.shape[0]): + each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device) + each_query_pred_boxes = pred_boxes_as_corners[i] + ious, _ = box_iou(each_query_box, each_query_pred_boxes) + + # If there are no overlapping boxes, fall back to generalized IoU + if torch.all(ious[0] == 0.0): + ious = generalized_box_iou(each_query_box, each_query_pred_boxes) + + # Use an adaptive threshold to include all boxes within 80% of the best IoU + iou_threshold = torch.max(ious) * 0.8 + + selected_inds = (ious[0] >= iou_threshold).nonzero() + if selected_inds.numel(): + selected_embeddings = class_embeds[i][selected_inds.squeeze(1)] + mean_embeds = torch.mean(class_embeds[i], axis=0) + mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings) + best_box_ind = selected_inds[torch.argmin(mean_sim)] + best_class_embeds.append(class_embeds[i][best_box_ind]) + best_box_indices.append(best_box_ind) + + if best_class_embeds: + query_embeds = torch.stack(best_class_embeds) + box_indices = torch.stack(best_box_indices) + else: + query_embeds, box_indices = None, None + + return query_embeds, box_indices, pred_boxes + + @add_start_docstrings_to_model_forward(OWLV2_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Owlv2ImageGuidedObjectDetectionOutput, config_class=Owlv2Config) + def image_guided_detection( + self, + pixel_values: torch.FloatTensor, + query_pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Owlv2ImageGuidedObjectDetectionOutput: + r""" + Returns: + + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + >>> from transformers import AutoProcessor, Owlv2ForObjectDetection + + >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble") + >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg" + >>> query_image = Image.open(requests.get(query_url, stream=True).raw) + >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt") + + >>> # forward pass + >>> with torch.no_grad(): + ... outputs = model.image_guided_detection(**inputs) + + >>> target_sizes = torch.Tensor([image.size[::-1]]) + + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_image_guided_detection( + ... outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes + ... 
) + >>> i = 0 # Retrieve predictions for the first image + >>> boxes, scores = results[i]["boxes"], results[i]["scores"] + >>> for box, score in zip(boxes, scores): + ... box = [round(i, 2) for i in box.tolist()] + ... print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}") + Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06] + Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39] + Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8] + Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83] + Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82] + Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05] + Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01] + Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72] + Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18] + Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21] + Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76] + Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07] + Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Compute feature maps for the input and query images + query_feature_map = self.image_embedder( + pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + )[0] + feature_map, vision_outputs = self.image_embedder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape + query_image_feats = torch.reshape( + query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim) + ) + # Get top class embedding and best box index for each query image in batch + query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query( + query_image_feats, query_feature_map, interpolate_pos_encoding + ) + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds) + + # Predict object boxes + target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + if not return_dict: + output = ( + feature_map, + query_feature_map, + target_pred_boxes, + query_pred_boxes, + pred_logits, + class_embeds, + vision_outputs.to_tuple(), + ) + output = tuple(x for x in output if x is not None) + return output + + return Owlv2ImageGuidedObjectDetectionOutput( + image_embeds=feature_map, + query_image_embeds=query_feature_map, + 
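+            # Image-guided detection never runs the text tower, which is why the
+            # `text_model_output` field of the output below is set to `None`.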
target_pred_boxes=target_pred_boxes, + query_pred_boxes=query_pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_output=None, + vision_model_output=vision_outputs, + ) + + @add_start_docstrings_to_model_forward(OWLV2_OBJECT_DETECTION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Owlv2ObjectDetectionOutput, config_class=Owlv2Config) + def forward( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Owlv2ObjectDetectionOutput: + r""" + Returns: + + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + + >>> from transformers import Owlv2Processor, Owlv2ForObjectDetection + + >>> processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble") + >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text_labels = [["a photo of a cat", "a photo of a dog"]] + >>> inputs = processor(text=text_labels, images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.tensor([(image.height, image.width)]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_grounded_object_detection( + ... outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels + ... ) + >>> # Retrieve predictions for the first image for the corresponding text queries + >>> result = results[0] + >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"] + >>> for box, score, text_label in zip(boxes, scores, text_labels): + ... box = [round(i, 2) for i in box.tolist()] + ... 
print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") + Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35] + Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Embed images and text queries + query_embeds, feature_map, outputs = self.image_text_embedder( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + # Text and vision model outputs + text_outputs = outputs.text_model_output + vision_outputs = outputs.vision_model_output + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] + max_text_queries = input_ids.shape[0] // batch_size + query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1]) + + # If first token is 0, then this is a padded query [batch_size, num_queries]. + input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1]) + query_mask = input_ids[..., 0] > 0 + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask) + + # Predict objectness + objectness_logits = self.objectness_predictor(image_feats) + + # Predict object boxes + pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + if not return_dict: + output = ( + pred_logits, + objectness_logits, + pred_boxes, + query_embeds, + feature_map, + class_embeds, + text_outputs.to_tuple(), + vision_outputs.to_tuple(), + ) + output = tuple(x for x in output if x is not None) + return output + + return Owlv2ObjectDetectionOutput( + image_embeds=feature_map, + text_embeds=query_embeds, + pred_boxes=pred_boxes, + logits=pred_logits, + objectness_logits=objectness_logits, + class_embeds=class_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = ["Owlv2Model", "Owlv2PreTrainedModel", "Owlv2TextModel", "Owlv2VisionModel", "Owlv2ForObjectDetection"] diff --git a/docs/transformers/src/transformers/models/owlv2/processing_owlv2.py b/docs/transformers/src/transformers/models/owlv2/processing_owlv2.py new file mode 100644 index 0000000000000000000000000000000000000000..4996cae7ab9694e2114cdc046ca95837a870158f --- /dev/null +++ b/docs/transformers/src/transformers/models/owlv2/processing_owlv2.py @@ -0,0 +1,320 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for OWLv2
+"""
+
+import warnings
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    Unpack,
+    _validate_images_text_input_order,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available
+
+
+if TYPE_CHECKING:
+    from .modeling_owlv2 import Owlv2ImageGuidedObjectDetectionOutput, Owlv2ObjectDetectionOutput
+
+
+class Owlv2ImagesKwargs(ImagesKwargs, total=False):
+    query_images: Optional[ImageInput]
+
+
+class Owlv2ProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Owlv2ImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": "max_length",
+        },
+        "images_kwargs": {},
+        "common_kwargs": {
+            "return_tensors": "np",
+        },
+    }
+
+
+class Owlv2Processor(ProcessorMixin):
+    r"""
+    Constructs an Owlv2 processor which wraps [`Owlv2ImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into
+    a single processor that inherits both the image processor and tokenizer functionalities. See the
+    [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`Owlv2ImageProcessor`]):
+            The image processor is a required input.
+        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
+            The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "Owlv2ImageProcessor"
+    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+    # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
+    optional_call_args = ["query_images"]
+
+    def __init__(self, image_processor, tokenizer, **kwargs):
+        super().__init__(image_processor, tokenizer)
+
+    # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OwlViT->Owlv2
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        # The following is to capture `query_images` argument that may be passed as a positional argument.
+        # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
+        # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
+        # This behavior is only needed for backward compatibility and will be removed in future versions.
+        #
+        *args,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Owlv2ProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several text(s) and image(s) for the model. This method forwards the `text` and
+        `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+            `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The query image to be prepared, one query image is expected per target image to be queried. Each image
+                can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
+                should be of shape (C, H, W), where C is a number of channels, H and W are image height and width.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+            - **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`.
+        """
+        output_kwargs = self._merge_kwargs(
+            Owlv2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+            **self.prepare_and_validate_optional_call_args(*args),
+        )
+        query_images = output_kwargs["images_kwargs"].pop("query_images", None)
+        return_tensors = output_kwargs["common_kwargs"]["return_tensors"]
+
+        if text is None and query_images is None and images is None:
+            raise ValueError(
+                "You have to specify at least one text, query image or image. All three cannot be None."
+ ) + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + data = {} + if text is not None: + if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): + encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])] + + elif isinstance(text, List) and isinstance(text[0], List): + encodings = [] + + # Maximum number of queries across batch + max_num_queries = max([len(text_single) for text_single in text]) + + # Pad all batch samples to max number of text queries + for text_single in text: + if len(text_single) != max_num_queries: + text_single = text_single + [" "] * (max_num_queries - len(text_single)) + + encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"]) + encodings.append(encoding) + else: + raise TypeError("Input text should be a string, a list of strings or a nested list of strings") + + if return_tensors == "np": + input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "jax" and is_flax_available(): + import jax.numpy as jnp + + input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "pt" and is_torch_available(): + import torch + + input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) + attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) + + elif return_tensors == "tf" and is_tf_available(): + import tensorflow as tf + + input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) + + else: + raise ValueError("Target return tensor type could not be returned") + + data["input_ids"] = input_ids + data["attention_mask"] = attention_mask + + if query_images is not None: + query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values + # Query images always override the text prompt + data = {"query_pixel_values": query_pixel_values} + + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data["pixel_values"] = image_features.pixel_values + + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_object_detection with OwlViT->Owlv2 + def post_process_object_detection(self, *args, **kwargs): + """ + This method forwards all its arguments to [`Owlv2ImageProcessor.post_process_object_detection`]. Please refer + to the docstring of this method for more information. + """ + warnings.warn( + "`post_process_object_detection` method is deprecated for OwlVitProcessor and will be removed in v5. 
" + "Use `post_process_grounded_object_detection` instead.", + FutureWarning, + ) + return self.image_processor.post_process_object_detection(*args, **kwargs) + + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_grounded_object_detection with OwlViT->Owlv2 + def post_process_grounded_object_detection( + self, + outputs: "Owlv2ObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + text_labels: Optional[List[List[str]]] = None, + ): + """ + Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`Owlv2ObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + text_labels (`List[List[str]]`, *optional*): + List of lists of text labels for each image in the batch. If unset, "text_labels" in output will be + set to `None`. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + - "text_labels": The text labels for each predicted bounding box on the image. + """ + output = self.image_processor.post_process_object_detection( + outputs=outputs, threshold=threshold, target_sizes=target_sizes + ) + + if text_labels is not None and len(text_labels) != len(output): + raise ValueError("Make sure that you pass in as many lists of text labels as images") + + # adding text labels to the output + if text_labels is not None: + for image_output, image_text_labels in zip(output, text_labels): + object_text_labels = [image_text_labels[i] for i in image_output["labels"]] + image_output["text_labels"] = object_text_labels + else: + for image_output in output: + image_output["text_labels"] = None + + return output + + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_image_guided_detection with OwlViT->Owlv2 + def post_process_image_guided_detection( + self, + outputs: "Owlv2ImageGuidedObjectDetectionOutput", + threshold: float = 0.0, + nms_threshold: float = 0.3, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + ): + """ + Converts the output of [`Owlv2ForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`Owlv2ImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. 
+ + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + - "labels": Set to `None`. + """ + return self.image_processor.post_process_image_guided_detection( + outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes + ) + + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.batch_decode + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.decode + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + +__all__ = ["Owlv2Processor"] diff --git a/docs/transformers/src/transformers/models/owlvit/__init__.py b/docs/transformers/src/transformers/models/owlvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fd80c2bcc6b268cd47a7afe321364ce3ed1aa5 --- /dev/null +++ b/docs/transformers/src/transformers/models/owlvit/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_owlvit import * + from .feature_extraction_owlvit import * + from .image_processing_owlvit import * + from .image_processing_owlvit_fast import * + from .modeling_owlvit import * + from .processing_owlvit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/owlvit/configuration_owlvit.py b/docs/transformers/src/transformers/models/owlvit/configuration_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..d80ac73d57d4431189c2444f9fb253a984d276b1 --- /dev/null +++ b/docs/transformers/src/transformers/models/owlvit/configuration_owlvit.py @@ -0,0 +1,335 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OWL-ViT model configuration""" + +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional + + +if TYPE_CHECKING: + from ...processing_utils import ProcessorMixin + from ...utils import TensorType + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class OwlViTTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`OwlViTTextModel`]. It is used to instantiate an + OwlViT text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OwlViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the OWL-ViT text model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`OwlViTTextModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 16): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token in the input sequences. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the input sequences. 
+ eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the input sequences. + + Example: + + ```python + >>> from transformers import OwlViTTextConfig, OwlViTTextModel + + >>> # Initializing a OwlViTTextModel with google/owlvit-base-patch32 style configuration + >>> configuration = OwlViTTextConfig() + + >>> # Initializing a OwlViTTextConfig from the google/owlvit-base-patch32 style configuration + >>> model = OwlViTTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlvit_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=16, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=0, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +class OwlViTVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`OwlViTVisionModel`]. It is used to instantiate + an OWL-ViT image encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OWL-ViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 768): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import OwlViTVisionConfig, OwlViTVisionModel + + >>> # Initializing a OwlViTVisionModel with google/owlvit-base-patch32 style configuration + >>> configuration = OwlViTVisionConfig() + + >>> # Initializing a OwlViTVisionModel model from the google/owlvit-base-patch32 style configuration + >>> model = OwlViTVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlvit_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=768, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +class OwlViTConfig(PretrainedConfig): + r""" + [`OwlViTConfig`] is the configuration class to store the configuration of an [`OwlViTModel`]. It is used to + instantiate an OWL-ViT model according to the specified arguments, defining the text model and vision model + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the OWL-ViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`OwlViTTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`OwlViTVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The initial value of the *logit_scale* parameter. Default is used as per the original OWL-ViT + implementation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not the model should return a dictionary. If `False`, returns a tuple. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + model_type = "owlvit" + sub_configs = {"text_config": OwlViTTextConfig, "vision_config": OwlViTVisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + projection_dim=512, + logit_scale_init_value=2.6592, + return_dict=True, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. 
Initializing the OwlViTTextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. initializing the OwlViTVisionConfig with default values.") + + self.text_config = OwlViTTextConfig(**text_config) + self.vision_config = OwlViTVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.return_dict = return_dict + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: Dict, vision_config: Dict, **kwargs): + r""" + Instantiate a [`OwlViTConfig`] (or a derived class) from owlvit text model configuration and owlvit vision + model configuration. + + Returns: + [`OwlViTConfig`]: An instance of a configuration object + """ + config_dict = {} + config_dict["text_config"] = text_config + config_dict["vision_config"] = vision_config + + return cls.from_dict(config_dict, **kwargs) + + +class OwlViTOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("logits_per_image", {0: "batch"}), + ("logits_per_text", {0: "batch"}), + ("text_embeds", {0: "batch"}), + ("image_embeds", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + processor: "ProcessorMixin", + batch_size: int = -1, + seq_length: int = -1, + framework: Optional["TensorType"] = None, + ) -> Mapping[str, Any]: + text_input_dict = super().generate_dummy_inputs( + processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework + ) + image_input_dict = super().generate_dummy_inputs( + processor.image_processor, batch_size=batch_size, framework=framework + ) + return {**text_input_dict, **image_input_dict} + + @property + def default_onnx_opset(self) -> int: + return 14 + + +__all__ = ["OwlViTConfig", "OwlViTOnnxConfig", "OwlViTTextConfig", "OwlViTVisionConfig"] diff --git a/docs/transformers/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py b/docs/transformers/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9fbb950467b124b44fcf0d686a3f2af04b3bae --- /dev/null +++ b/docs/transformers/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py @@ -0,0 +1,406 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OWL-ViT checkpoints from the original repository. 
URL:
+https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit"""
+
+import argparse
+import collections.abc
+
+import jax
+import jax.numpy as jnp
+import torch
+import torch.nn as nn
+from clip.model import CLIP
+from flax.training import checkpoints
+from huggingface_hub import Repository
+
+from transformers import (
+    CLIPTokenizer,
+    OwlViTConfig,
+    OwlViTForObjectDetection,
+    OwlViTImageProcessor,
+    OwlViTModel,
+    OwlViTProcessor,
+)
+
+
+CONFIGS = {
+    "vit_b32": {
+        "embed_dim": 512,
+        "image_resolution": 768,
+        "context_length": 16,
+        "vocab_size": 49408,
+        "vision_layers": 12,
+        "vision_width": 768,
+        "vision_patch_size": 32,
+        "transformer_width": 512,
+        "transformer_heads": 8,
+        "transformer_layers": 12,
+    },
+    "vit_b16": {
+        "embed_dim": 512,
+        "image_resolution": 768,
+        "context_length": 16,
+        "vocab_size": 49408,
+        "vision_layers": 12,
+        "vision_width": 768,
+        "vision_patch_size": 16,
+        "transformer_width": 512,
+        "transformer_heads": 8,
+        "transformer_layers": 12,
+    },
+    "vit_l14": {
+        "embed_dim": 768,
+        "image_resolution": 840,
+        "context_length": 16,
+        "vocab_size": 49408,
+        "vision_layers": 24,
+        "vision_width": 1024,
+        "vision_patch_size": 14,
+        "transformer_width": 768,
+        "transformer_heads": 12,
+        "transformer_layers": 12,
+    },
+}
+
+
+def flatten_nested_dict(params, parent_key="", sep="/"):
+    items = []
+
+    for k, v in params.items():
+        new_key = parent_key + sep + k if parent_key else k
+
+        # `collections.MutableMapping` was removed in Python 3.10; use the `collections.abc` variant.
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def to_f32(params):
+    return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)
+
+
+def copy_attn_layer(hf_attn_layer, pt_attn_layer):
+    q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
+    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
+
+    out_proj_weights = pt_attn_layer.out_proj.weight
+    out_proj_bias = pt_attn_layer.out_proj.bias
+
+    hf_attn_layer.q_proj.weight.data = q_proj
+    hf_attn_layer.q_proj.bias.data = q_proj_bias
+
+    hf_attn_layer.k_proj.weight.data = k_proj
+    hf_attn_layer.k_proj.bias.data = k_proj_bias
+
+    hf_attn_layer.v_proj.weight.data = v_proj
+    hf_attn_layer.v_proj.bias.data = v_proj_bias
+
+    hf_attn_layer.out_proj.weight = out_proj_weights
+    hf_attn_layer.out_proj.bias = out_proj_bias
+
+
+def copy_mlp(hf_mlp, pt_mlp):
+    copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
+    copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
+
+
+def copy_linear(hf_linear, pt_linear):
+    hf_linear.weight = pt_linear.weight
+    hf_linear.bias = pt_linear.bias
+
+
+def copy_layer(hf_layer, pt_layer):
+    # copy layer norms
+    copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
+    copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
+
+    # copy MLP
+    copy_mlp(hf_layer.mlp, pt_layer.mlp)
+
+    # copy attn
+    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
+
+
+def copy_layers(hf_layers, pt_layers):
+    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
+        copy_layer(hf_layer, pt_layer)
+
+
+def copy_encoder(hf_encoder, pt_model):
+    # copy embeds
+    hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
+    hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
+
+    # copy layer norm
+    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
+
+    # copy hidden layers
+    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
+
+
+def 
copy_text_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.text_projection.weight.data = pt_model.text_projection.data.T + + # copy text encoder + copy_encoder(hf_model.text_model, pt_model) + + +def copy_vision_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T + + # copy layer norms + copy_linear(hf_model.vision_model.pre_layernorm, pt_model.visual.ln_pre) + copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post) + + # copy embeds + hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data + hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding + hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data + + # copy encoder + copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks) + + +def copy_class_merge_token(hf_model, flax_params): + flax_class_token_params = flatten_nested_dict(flax_params["backbone"]["merged_class_token"]) + + weight = torch.from_numpy(flax_class_token_params["scale"]) + bias = torch.from_numpy(flax_class_token_params["bias"]) + hf_model.layer_norm.weight = nn.Parameter(weight) + hf_model.layer_norm.bias = nn.Parameter(bias) + + +def copy_class_box_heads(hf_model, flax_params): + pt_params = hf_model.state_dict() + new_params = {} + + # Rename class prediction head flax params to pytorch HF + flax_class_params = flatten_nested_dict(flax_params["class_head"]) + + for flax_key, v in flax_class_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace(".kernel", ".weight") + torch_key = torch_key.replace("Dense_0", "dense0") + torch_key = "class_head." + torch_key + + if "weight" in torch_key and v.ndim == 2: + v = v.T + + new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) + + # Rename box prediction box flax params to pytorch HF + flax_box_params = flatten_nested_dict(flax_params["obj_box_head"]) + + for flax_key, v in flax_box_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace(".kernel", ".weight") + torch_key = torch_key.replace("_", "").lower() + torch_key = "box_head." 
+ torch_key + + if "weight" in torch_key and v.ndim == 2: + v = v.T + + new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) + + # Copy flax params to PyTorch params + for name, param in new_params.items(): + if name in pt_params.keys(): + pt_params[name].copy_(param) + + +def copy_flax_attn_params(hf_backbone, flax_attn_params): + for k, v in flax_attn_params.items(): + if k.startswith("transformer"): + torch_key = k.replace("transformer.resblocks", "text_model.encoder.layers") + else: + torch_key = k.replace("visual.transformer.resblocks", "vision_model.encoder.layers") + + torch_key = torch_key.replace("attn", "self_attn") + torch_key = torch_key.replace("key", "k_proj") + torch_key = torch_key.replace("value", "v_proj") + torch_key = torch_key.replace("query", "q_proj") + torch_key = torch_key.replace("out", "out_proj") + + if "bias" in torch_key and v.ndim == 2: + shape = v.shape[0] * v.shape[1] + v = v.reshape(shape) + + if "weight" in torch_key and "out" in torch_key: + shape = (v.shape[0] * v.shape[1], v.shape[2]) + v = v.reshape(shape).T + + if "weight" in torch_key and "out" not in torch_key: + shape = (v.shape[0], v.shape[1] * v.shape[2]) + v = v.reshape(shape).T + + # Copy flax CLIP attn params to HF PyTorch params + v = torch.from_numpy(v) + hf_backbone.state_dict()[torch_key].copy_(v) + + +def _convert_attn_layers(params): + new_params = {} + processed_attn_layers = [] + + for k, v in params.items(): + if "attn." in k: + base = k[: k.rindex("attn.") + 5] + if base in processed_attn_layers: + continue + + processed_attn_layers.append(base) + dim = params[base + "out.weight"].shape[-1] + new_params[base + "out_proj.weight"] = params[base + "out.weight"].reshape(dim, dim).T + new_params[base + "out_proj.bias"] = params[base + "out.bias"] + else: + new_params[k] = v + return new_params + + +def convert_clip_backbone(flax_params, torch_config): + torch_model = CLIP(**torch_config) + torch_model.eval() + torch_clip_params = torch_model.state_dict() + + flax_clip_params = flatten_nested_dict(flax_params["backbone"]["clip"]) + new_torch_params = {} + + for flax_key, v in flax_clip_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace("text.token_embedding.embedding", "token_embedding.kernel") + + if ( + torch_key.startswith("text.transformer") + or torch_key.startswith("text.text_projection") + or torch_key.startswith("text.ln_final") + or torch_key.startswith("text.positional_embedding") + ): + torch_key = torch_key[5:] + + torch_key = torch_key.replace("text_projection.kernel", "text_projection") + torch_key = torch_key.replace("visual.proj.kernel", "visual.proj") + torch_key = torch_key.replace(".scale", ".weight") + torch_key = torch_key.replace(".kernel", ".weight") + + if "conv" in torch_key or "downsample.0.weight" in torch_key: + v = v.transpose(3, 2, 0, 1) + + elif "weight" in torch_key and v.ndim == 2 and "embedding" not in torch_key: + # Fully connected layers are transposed, embeddings are not + v = v.T + + new_torch_params[torch_key] = v + + attn_params = _convert_attn_layers(new_torch_params) + new_torch_params.update(attn_params) + attn_params = {} + + # Copy flax CLIP backbone params to PyTorch params + for name, param in new_torch_params.items(): + if name in torch_clip_params.keys(): + new_param = torch.from_numpy(new_torch_params[name]) + torch_clip_params[name].copy_(new_param) + else: + attn_params[name] = param + + return torch_clip_params, torch_model, attn_params + + +@torch.no_grad() +def 
convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None):
+    """
+    Copy/paste/tweak model's weights to transformers design.
+    """
+    repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}")
+    repo.git_pull()
+
+    if config_path is not None:
+        config = OwlViTConfig.from_pretrained(config_path)
+    else:
+        config = OwlViTConfig()
+
+    hf_backbone = OwlViTModel(config).eval()
+    hf_model = OwlViTForObjectDetection(config).eval()
+
+    copy_text_model_and_projection(hf_backbone, pt_backbone)
+    copy_vision_model_and_projection(hf_backbone, pt_backbone)
+    hf_backbone.logit_scale = pt_backbone.logit_scale
+    copy_flax_attn_params(hf_backbone, attn_params)
+
+    hf_model.owlvit = hf_backbone
+    copy_class_merge_token(hf_model, flax_params)
+    copy_class_box_heads(hf_model, flax_params)
+
+    # Save HF model
+    hf_model.save_pretrained(repo.local_dir)
+
+    # Initialize image processor
+    image_processor = OwlViTImageProcessor(
+        size=config.vision_config.image_size, crop_size=config.vision_config.image_size
+    )
+    # Initialize tokenizer
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16)
+
+    # Initialize processor
+    processor = OwlViTProcessor(image_processor=image_processor, tokenizer=tokenizer)
+    image_processor.save_pretrained(repo.local_dir)
+    processor.save_pretrained(repo.local_dir)
+
+    repo.git_add()
+    repo.git_commit("Upload model and processor")
+    repo.git_push()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--owlvit_version",
+        default=None,
+        type=str,
+        required=True,
+        help="OWL-ViT model name [clip_b16, clip_b32, clip_l14].",
+    )
+    parser.add_argument(
+        "--owlvit_checkpoint", default=None, type=str, required=True, help="Path to flax model checkpoint."
+    )
+    parser.add_argument("--hf_config", default=None, type=str, required=True, help="Path to HF model config.")
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default="hf_model", type=str, help="Path to the output PyTorch model."
+    )
+    args = parser.parse_args()
+
+    # Initialize PyTorch CLIP model
+    model_name = args.owlvit_version
+    if model_name == "clip_b16":
+        torch_config = CONFIGS["vit_b16"]
+    elif model_name == "clip_b32":
+        torch_config = CONFIGS["vit_b32"]
+    elif model_name == "clip_l14":
+        torch_config = CONFIGS["vit_l14"]
+    else:
+        raise ValueError(f"Unknown OWL-ViT version: {model_name}. Choose from [clip_b16, clip_b32, clip_l14].")
+
+    # Load from checkpoint and convert params to float-32
+    variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"]
+    flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
+    del variables
+
+    # Convert CLIP backbone
+    pt_backbone_params, clip_pt, attn_params = convert_clip_backbone(flax_params, torch_config)
+
+    convert_owlvit_checkpoint(clip_pt, flax_params, attn_params, args.pytorch_dump_folder_path, args.hf_config)
diff --git a/docs/transformers/src/transformers/models/owlvit/feature_extraction_owlvit.py b/docs/transformers/src/transformers/models/owlvit/feature_extraction_owlvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee3a8d0b145cc54e8ccebf7acc84538236fac644
--- /dev/null
+++ b/docs/transformers/src/transformers/models/owlvit/feature_extraction_owlvit.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for OwlViT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_owlvit import OwlViTImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class OwlViTFeatureExtractor(OwlViTImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class OwlViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use OwlViTImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["OwlViTFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/owlvit/image_processing_owlvit.py b/docs/transformers/src/transformers/models/owlvit/image_processing_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd5007f990c504c3c7fb3a8fd00f2fb48065937 --- /dev/null +++ b/docs/transformers/src/transformers/models/owlvit/image_processing_owlvit.py @@ -0,0 +1,625 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
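Since the class above is only a deprecation shim, migrating is a one-line change. A minimal sketch, assuming the vision extras of `transformers` are installed (the warning filter is only there to make the `FutureWarning` observable):

```python
import warnings

from transformers import OwlViTFeatureExtractor, OwlViTImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = OwlViTFeatureExtractor()  # still constructible, but deprecated

assert any(issubclass(w.category, FutureWarning) for w in caught)

# Drop-in replacement with identical preprocessing behavior, no warning:
image_processor = OwlViTImageProcessor()
```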
+"""Image processor class for OwlViT""" + +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + center_crop, + center_to_corners_format, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_torch_available, logging +from ...utils.import_utils import requires + + +if TYPE_CHECKING: + from .modeling_owlvit import OwlViTObjectDetectionOutput + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +def _upcast(t): + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +def _scale_boxes(boxes, target_sizes): + """ + Scale batch of bounding boxes to the target sizes. + + Args: + boxes (`torch.Tensor` of shape `(batch_size, num_boxes, 4)`): + Bounding boxes to scale. Each box is expected to be in (x1, y1, x2, y2) format. + target_sizes (`List[Tuple[int, int]]` or `torch.Tensor` of shape `(batch_size, 2)`): + Target sizes to scale the boxes to. Each target size is expected to be in (height, width) format. + + Returns: + `torch.Tensor` of shape `(batch_size, num_boxes, 4)`: Scaled bounding boxes. + """ + + if isinstance(target_sizes, (list, tuple)): + image_height = torch.tensor([i[0] for i in target_sizes]) + image_width = torch.tensor([i[1] for i in target_sizes]) + elif isinstance(target_sizes, torch.Tensor): + image_height, image_width = target_sizes.unbind(1) + else: + raise ValueError("`target_sizes` must be a list, tuple or torch.Tensor") + + scale_factor = torch.stack([image_width, image_height, image_width, image_height], dim=1) + scale_factor = scale_factor.unsqueeze(1).to(boxes.device) + boxes = boxes * scale_factor + return boxes + + +def box_area(boxes): + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +@requires(backends=("vision",)) +class OwlViTImageProcessor(BaseImageProcessor): + r""" + Constructs an OWL-ViT image processor. + + This image processor inherits from [`ImageProcessingMixin`] which contains most of the main methods. 
Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the shorter edge of the input to a certain `size`.
+        size (`Dict[str, int]`, *optional*, defaults to {"height": 768, "width": 768}):
+            The size to use for resizing the image. Only has an effect if `do_resize` is set to `True`. If `size` is a
+            sequence like (h, w), output size will be matched to this. If `size` is an int, then image will be resized
+            to (size, size).
+        resample (`int`, *optional*, defaults to `Resampling.BICUBIC`):
+            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
+            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
+            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
+            to `True`.
+        do_center_crop (`bool`, *optional*, defaults to `False`):
+            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+            image is padded with 0's and then center cropped.
+        crop_size (`Dict[str, int]`, *optional*, defaults to {"height": 768, "width": 768}):
+            The size to use for center cropping the image. Only has an effect if `do_center_crop` is set to `True`.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the input by a certain factor.
+        rescale_factor (`float`, *optional*, defaults to `1/255`):
+            The factor to use for rescaling the image. Only has an effect if `do_rescale` is set to `True`.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the input with `image_mean` and `image_std`.
+        image_mean (`List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            The sequence of means for each channel, to be used when normalizing images.
+        image_std (`List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize=True,
+        size=None,
+        resample=PILImageResampling.BICUBIC,
+        do_center_crop=False,
+        crop_size=None,
+        do_rescale=True,
+        rescale_factor=1 / 255,
+        do_normalize=True,
+        image_mean=None,
+        image_std=None,
+        **kwargs,
+    ):
+        size = size if size is not None else {"height": 768, "width": 768}
+        size = get_size_dict(size, default_to_square=True)
+
+        crop_size = crop_size if crop_size is not None else {"height": 768, "width": 768}
+        crop_size = get_size_dict(crop_size, default_to_square=True)
+
+        # Early versions of the OWL-ViT config on the hub had "rescale" as a flag. This clashes with the
+        # vision image processor method `rescale` as it would be set as an attribute during the super().__init__
+        # call. This is for backwards compatibility.
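+        # For example, a legacy config containing a boolean `"rescale"` key is remapped to `do_rescale`
+        # below, so the stored value never shadows the `rescale()` method defined further down.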
+ if "rescale" in kwargs: + rescale_val = kwargs.pop("rescale") + kwargs["do_rescale"] = rescale_val + + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to a certain size. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + The size to resize the image to. Must contain height and width keys. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use when resizing the input. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size, default_to_square=True) + if "height" not in size or "width" not in size: + raise ValueError("size dictionary must contain height and width keys") + + return resize( + image, + (size["height"], size["width"]), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def center_crop( + self, + image: np.ndarray, + crop_size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to a certain size. + + Args: + image (`np.ndarray`): + Image to center crop. + crop_size (`Dict[str, int]`): + The size to center crop the image to. Must contain height and width keys. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + crop_size = get_size_dict(crop_size, default_to_square=True) + if "height" not in crop_size or "width" not in crop_size: + raise ValueError("crop_size dictionary must contain height and width keys") + + return center_crop( + image, + (crop_size["height"], crop_size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. 
+ data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[Dict[str, int]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Prepares an image or batch of images for the model. + + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Expects a single or batch of images with pixel values + ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether or not to resize the input. If `True`, will resize the input to the size specified by `size`. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + The size to resize the input to. Only has an effect if `do_resize` is set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + The resampling filter to use when resizing the input. Only has an effect if `do_resize` is set to + `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether or not to center crop the input. If `True`, will center crop the input to the size specified by + `crop_size`. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + The size to center crop the input to. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether or not to rescale the input. If `True`, will rescale the input by dividing it by + `rescale_factor`. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + The factor to rescale the input by. Only has an effect if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether or not to normalize the input. If `True`, will normalize the input by subtracting `image_mean` + and dividing by `image_std`. + image_mean (`Union[float, List[float]]`, *optional*, defaults to `self.image_mean`): + The mean to subtract from the input when normalizing. Only has an effect if `do_normalize` is set to + `True`. 
+ image_std (`Union[float, List[float]]`, *optional*, defaults to `self.image_std`): + The standard deviation to divide the input by when normalizing. Only has an effect if `do_normalize` is + set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
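+            # e.g. for a batch of (height, width, 3) numpy arrays this infers ChannelDimension.LAST
+            # once from the first image and reuses it for every image below.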
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image, crop_size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + return encoded_inputs + + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + # TODO: (amy) add support for other frameworks + warnings.warn( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + FutureWarning, + ) + + logits, boxes = outputs.logits, outputs.pred_boxes + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + def post_process_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. 
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + """ + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + batch_size = len(batch_logits) + + if target_sizes is not None and len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as images") + + # batch_logits of shape (batch_size, num_queries, num_classes) + batch_class_logits = torch.max(batch_logits, dim=-1) + batch_scores = torch.sigmoid(batch_class_logits.values) + batch_labels = batch_class_logits.indices + + # Convert to [x0, y0, x1, y1] format + batch_boxes = center_to_corners_format(batch_boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + batch_boxes = _scale_boxes(batch_boxes, target_sizes) + + results = [] + for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes): + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + results.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return results + + def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. All labels are set to None as + `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. 
+ """ + logits, target_boxes = outputs.logits, outputs.target_pred_boxes + + if target_sizes is not None and len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + + # Convert to [x0, y0, x1, y1] format + target_boxes = center_to_corners_format(target_boxes) + + # Apply non-maximum suppression (NMS) + if nms_threshold < 1.0: + for idx in range(target_boxes.shape[0]): + for i in torch.argsort(-scores[idx]): + if not scores[idx][i]: + continue + + ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] + ious[i] = -1.0 # Mask self-IoU. + scores[idx][ious > nms_threshold] = 0.0 + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + target_boxes = _scale_boxes(target_boxes, target_sizes) + + # Compute box display alphas based on prediction scores + results = [] + alphas = torch.zeros_like(scores) + + for idx in range(target_boxes.shape[0]): + # Select scores for boxes matching the current query: + query_scores = scores[idx] + if not query_scores.nonzero().numel(): + continue + + # Apply threshold on scores before scaling + query_scores[query_scores < threshold] = 0.0 + + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. + max_score = torch.max(query_scores) + 1e-6 + query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) + query_alphas = torch.clip(query_alphas, 0.0, 1.0) + alphas[idx] = query_alphas + + mask = alphas[idx] > 0 + box_scores = alphas[idx][mask] + boxes = target_boxes[idx][mask] + results.append({"scores": box_scores, "labels": None, "boxes": boxes}) + + return results + + +__all__ = ["OwlViTImageProcessor"] diff --git a/docs/transformers/src/transformers/models/owlvit/image_processing_owlvit_fast.py b/docs/transformers/src/transformers/models/owlvit/image_processing_owlvit_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..f048f1a0da51e45d821faa2823b7cb074a1992eb --- /dev/null +++ b/docs/transformers/src/transformers/models/owlvit/image_processing_owlvit_fast.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for OwlViT""" + +import warnings +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BaseImageProcessorFast, +) +from ...image_transforms import center_to_corners_format +from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling +from ...utils import TensorType, add_start_docstrings, is_torch_available, logging + + +if TYPE_CHECKING: + from .modeling_owlvit import OwlViTObjectDetectionOutput + + +if is_torch_available(): + import torch + + from .image_processing_owlvit import _scale_boxes, box_iou + + +logger = logging.get_logger(__name__) + + +@add_start_docstrings( + "Constructs a fast OwlViT image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class OwlViTImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 768, "width": 768} + default_to_square = True + crop_size = {"height": 768, "width": 768} + do_resize = True + do_center_crop = False + do_rescale = True + do_normalize = None + do_convert_rgb = None + model_input_names = ["pixel_values"] + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. 
+ """ + # TODO: (amy) add support for other frameworks + warnings.warn( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + FutureWarning, + ) + + logits, boxes = outputs.logits, outputs.pred_boxes + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + labels = probs.indices + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection + def post_process_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. 
+ """ + batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes + batch_size = len(batch_logits) + + if target_sizes is not None and len(target_sizes) != batch_size: + raise ValueError("Make sure that you pass in as many target sizes as images") + + # batch_logits of shape (batch_size, num_queries, num_classes) + batch_class_logits = torch.max(batch_logits, dim=-1) + batch_scores = torch.sigmoid(batch_class_logits.values) + batch_labels = batch_class_logits.indices + + # Convert to [x0, y0, x1, y1] format + batch_boxes = center_to_corners_format(batch_boxes) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + batch_boxes = _scale_boxes(batch_boxes, target_sizes) + + results = [] + for scores, labels, boxes in zip(batch_scores, batch_labels, batch_boxes): + keep = scores > threshold + scores = scores[keep] + labels = labels[keep] + boxes = boxes[keep] + results.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return results + + # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection + def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. All labels are set to None as + `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. + """ + logits, target_boxes = outputs.logits, outputs.target_pred_boxes + + if target_sizes is not None and len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + probs = torch.max(logits, dim=-1) + scores = torch.sigmoid(probs.values) + + # Convert to [x0, y0, x1, y1] format + target_boxes = center_to_corners_format(target_boxes) + + # Apply non-maximum suppression (NMS) + if nms_threshold < 1.0: + for idx in range(target_boxes.shape[0]): + for i in torch.argsort(-scores[idx]): + if not scores[idx][i]: + continue + + ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] + ious[i] = -1.0 # Mask self-IoU. 
+ scores[idx][ious > nms_threshold] = 0.0 + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + target_boxes = _scale_boxes(target_boxes, target_sizes) + + # Compute box display alphas based on prediction scores + results = [] + alphas = torch.zeros_like(scores) + + for idx in range(target_boxes.shape[0]): + # Select scores for boxes matching the current query: + query_scores = scores[idx] + if not query_scores.nonzero().numel(): + continue + + # Apply threshold on scores before scaling + query_scores[query_scores < threshold] = 0.0 + + # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. + # All other boxes will either belong to a different query, or will not be shown. + max_score = torch.max(query_scores) + 1e-6 + query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) + query_alphas = torch.clip(query_alphas, 0.0, 1.0) + alphas[idx] = query_alphas + + mask = alphas[idx] > 0 + box_scores = alphas[idx][mask] + boxes = target_boxes[idx][mask] + results.append({"scores": box_scores, "labels": None, "boxes": boxes}) + + return results + + +__all__ = ["OwlViTImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/owlvit/modeling_owlvit.py b/docs/transformers/src/transformers/models/owlvit/modeling_owlvit.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb0114cf11f9cc1cd9d0caa7302cf392baa87a1 --- /dev/null +++ b/docs/transformers/src/transformers/models/owlvit/modeling_owlvit.py @@ -0,0 +1,1780 @@ +# coding=utf-8 +# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
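Both the slow and fast post-processing paths above perform the same coordinate arithmetic: convert normalized `(center_x, center_y, width, height)` boxes to corner format, then multiply by `(width, height, width, height)`. A minimal self-contained sketch of that math on a toy box, assuming a 480x640 `(height, width)` image:

```python
import torch

# One box in normalized (center_x, center_y, width, height) format.
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])

# center-to-corners: (cx, cy, w, h) -> (x0, y0, x1, y1), as center_to_corners_format does.
corners = torch.cat(
    [boxes[..., :2] - boxes[..., 2:] / 2, boxes[..., :2] + boxes[..., 2:] / 2], dim=-1
)

# _scale_boxes multiplies by (width, height, width, height) to reach pixel coordinates.
height, width = 480, 640
scale = torch.tensor([width, height, width, height], dtype=torch.float32)
print(corners * scale)  # tensor([[256., 144., 384., 336.]])
```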
+"""PyTorch OWL-ViT model.""" + +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_vision_available, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/owlvit-base-patch32" + +# See all OwlViT models at https://huggingface.co/models?filter=owlvit + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit +def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class OwlViTOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`OwlViTVisionModel`]. + text_model_output (Tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.loss.loss_for_object_detection._upcast +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +# Copied from transformers.loss.loss_for_object_detection.box_area +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# Copied from transformers.loss.loss_for_object_detection.box_iou +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +# Copied from transformers.loss.loss_for_object_detection.generalized_box_iou +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +@dataclass +class OwlViTObjectDetectionOutput(ModelOutput): + """ + Output type of [`OwlViTForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. 
+        loss_dict (`Dict`, *optional*):
+            A dictionary containing the individual losses. Useful for logging.
+        logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
+            Classification logits (including no-object) for all queries.
+        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
+            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+            possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to retrieve the
+            unnormalized bounding boxes.
+        text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
+            The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
+            image embeddings for each patch.
+        class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
+            Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
+            number of patches is (image_size / patch_size)**2.
+        text_model_output (Tuple[`BaseModelOutputWithPooling`]):
+            The output of the [`OwlViTTextModel`].
+        vision_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`OwlViTVisionModel`].
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    loss_dict: Optional[Dict] = None
+    logits: Optional[torch.FloatTensor] = None
+    pred_boxes: Optional[torch.FloatTensor] = None
+    text_embeds: Optional[torch.FloatTensor] = None
+    image_embeds: Optional[torch.FloatTensor] = None
+    class_embeds: Optional[torch.FloatTensor] = None
+    text_model_output: BaseModelOutputWithPooling = None
+    vision_model_output: BaseModelOutputWithPooling = None
+
+    def to_tuple(self) -> Tuple[Any]:
+        return tuple(
+            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+            for k in self.keys()
+        )
+
+
+@dataclass
+class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
+    """
+    Output type of [`OwlViTForObjectDetection.image_guided_detection`].
+
+    Args:
+        logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
+            Classification logits (including no-object) for all queries.
+        target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
+            values are normalized in [0, 1], relative to the size of each individual target image in the batch
+            (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
+        query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
+            Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
+            values are normalized in [0, 1], relative to the size of each individual query image in the batch
+            (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+            Pooled output of [`OwlViTVisionModel`].
OWL-ViT represents images as a set of image patches and computes
+            image embeddings for each patch.
+        query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
+            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
+            image embeddings for each patch.
+        class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
+            Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
+            number of patches is (image_size / patch_size)**2.
+        text_model_output (Tuple[`BaseModelOutputWithPooling`]):
+            The output of the [`OwlViTTextModel`].
+        vision_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`OwlViTVisionModel`].
+    """
+
+    logits: Optional[torch.FloatTensor] = None
+    image_embeds: Optional[torch.FloatTensor] = None
+    query_image_embeds: Optional[torch.FloatTensor] = None
+    target_pred_boxes: Optional[torch.FloatTensor] = None
+    query_pred_boxes: Optional[torch.FloatTensor] = None
+    class_embeds: Optional[torch.FloatTensor] = None
+    text_model_output: BaseModelOutputWithPooling = None
+    vision_model_output: BaseModelOutputWithPooling = None
+
+    def to_tuple(self) -> Tuple[Any]:
+        return tuple(
+            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+            for k in self.keys()
+        )
+
+
+class OwlViTVisionEmbeddings(nn.Module):
+    def __init__(self, config: OwlViTVisionConfig):
+        super().__init__()
+        self.patch_size = config.patch_size
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (config.image_size // config.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+    # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings, making it possible to use the model on
+        higher resolution images. It is also adapted to support torch.jit tracing.
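+
+        As an illustration (numbers assumed for this sketch): with `image_size=768` and `patch_size=32`, the model is
+        pre-trained on a 24x24 grid of patch position embeddings, so running it on 960x960 inputs requires
+        interpolating that grid to 30x30.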
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        position_embedding = self.position_embedding.weight.unsqueeze(0)
+        num_positions = position_embedding.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embedding(self.position_ids)
+
+        class_pos_embed = position_embedding[:, :1]
+        patch_pos_embed = position_embedding[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [batch_size, embed_dim, height // patch_size, width // patch_size]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
+        return embeddings
+
+
+class OwlViTTextEmbeddings(nn.Module):
+    def __init__(self, config: OwlViTTextConfig):
+        super().__init__()
+        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+class OwlViTAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scale
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {causal_attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        # For int8 compatibility, sometimes the `attn_probs` are in `fp32`
+        attn_probs = attn_probs.to(value_states.dtype)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT
+class OwlViTMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT
+class OwlViTEncoderLayer(nn.Module):
+    def __init__(self, config: OwlViTConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = OwlViTAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = OwlViTMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class OwlViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = OwlViTConfig + base_model_prefix = "owlvit" + supports_gradient_checkpointing = True + _no_split_modules = ["OwlViTEncoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, OwlViTTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, OwlViTVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, OwlViTAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, OwlViTMLP): + factor = self.config.initializer_factor + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, OwlViTModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +OWLVIT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OwlViTConfig`]): Model configuration class with all the parameters of the model. 
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+OWLVIT_TEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+            IDs?](../glossary#input-ids).
+        attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
+            `vision_model_last_hidden_state` under returned tensors for more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+OWLVIT_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+        query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values of query image(s) to be detected. Pass in one query image per target image.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class OwlViTEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`OwlViTEncoderLayer`].
+
+    Args:
+        config: OwlViTConfig
+    """
+
+    def __init__(self, config: OwlViTConfig):
+        super().__init__()
+        self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class OwlViTTextTransformer(nn.Module): + def __init__(self, config: OwlViTTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = OwlViTTextEmbeddings(config) + self.encoder = OwlViTEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + 
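+        # Shape sketch (illustrative numbers, not enforced by the code): with batch_size=2,
+        # num_max_text_queries=3 and seq_len=16, `input_ids` was flattened above to (6, 16) and
+        # `hidden_states` is (6, 16, hidden_size); the causal mask built below is (6, 1, 16, 16),
+        # filled with large negative values above the diagonal so that softmax assigns (near-)zero
+        # attention to future tokens.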
+        # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
+        # OWLVIT's text model uses causal mask, prepare it here.
+        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+        causal_attention_mask = _create_4d_causal_attention_mask(
+            input_shape, hidden_states.dtype, device=hidden_states.device
+        )
+        # expand attention_mask
+        if attention_mask is not None:
+            # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+        # take features from the end of tokens embedding (end of token is the highest number in each sequence)
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+            input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
+        ]
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class OwlViTTextModel(OwlViTPreTrainedModel):
+    config_class = OwlViTTextConfig
+
+    def __init__(self, config: OwlViTTextConfig):
+        super().__init__(config)
+        self.text_model = OwlViTTextTransformer(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, value):
+        self.text_model.embeddings.token_embedding = value
+
+    @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        Examples:
+        ```python
+        >>> from transformers import AutoProcessor, OwlViTTextModel
+
+        >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
+        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+        >>> inputs = processor(
+        ...     text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
+        ... 
) + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + # Get embeddings for all text queries in all batch samples + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class OwlViTVisionTransformer(nn.Module): + def __init__(self, config: OwlViTVisionConfig): + super().__init__() + self.config = config + + self.embeddings = OwlViTVisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.encoder = OwlViTEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Cast the input to the expected `dtype` + expected_input_dtype = self.embeddings.patch_embedding.weight.dtype + pixel_values = pixel_values.to(expected_input_dtype) + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layernorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class OwlViTVisionModel(OwlViTPreTrainedModel): + config_class = OwlViTVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: OwlViTVisionConfig): + super().__init__(config) + self.vision_model = OwlViTVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + ```python + >>> from PIL import Image + 
>>> import requests
+        >>> from transformers import AutoProcessor, OwlViTVisionModel
+
+        >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
+        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled CLS states
+        ```"""
+        return self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+
+@add_start_docstrings(OWLVIT_START_DOCSTRING)
+class OwlViTModel(OwlViTPreTrainedModel):
+    config_class = OwlViTConfig
+
+    def __init__(self, config: OwlViTConfig):
+        super().__init__(config)
+
+        if not isinstance(config.text_config, OwlViTTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type OwlViTTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, OwlViTVisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type OwlViTVisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+        self.vision_embed_dim = vision_config.hidden_size
+
+        self.text_model = OwlViTTextTransformer(text_config)
+        self.vision_model = OwlViTVisionTransformer(vision_config)
+
+        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+        self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
+    def get_text_features(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`OwlViTTextModel`].
+
+        Examples:
+        ```python
+        >>> from transformers import AutoProcessor, OwlViTModel
+
+        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+        >>> inputs = processor(
+        ...     text=[["a photo of a cat", "a photo of a dog"], ["a photo of an astronaut"]], return_tensors="pt"
+        ... )
+        >>> text_features = model.get_text_features(**inputs)
+        ```"""
+        # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
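+        # Usage note (illustrative): `text_features` below has shape
+        # (batch_size * num_max_text_queries, projection_dim); unlike `forward`, it is returned
+        # unnormalized, so callers typically L2-normalize it before computing similarity scores.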
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Get embeddings for all text queries in all batch samples
+        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=return_dict)
+        pooled_output = text_output[1]
+        text_features = self.text_projection(pooled_output)
+
+        return text_features
+
+    @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
+    def get_image_features(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+            applying the projection layer to the pooled output of [`OwlViTVisionModel`].
+
+        Examples:
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, OwlViTModel
+
+        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> inputs = processor(images=image, return_tensors="pt")
+        >>> image_features = model.get_image_features(**inputs)
+        ```"""
+        # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        pooled_output = vision_outputs[1]
+        image_features = self.visual_projection(pooled_output)
+
+        return image_features
+
+    @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=OwlViTOutput, config_class=OwlViTConfig)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        return_loss: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_base_image_embeds: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, OwlViTOutput]:
+        r"""
+        Returns:
+
+        Examples:
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, OwlViTModel
+
+        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
+        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+        
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + # Get embeddings for all text queries in all batch samples + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + # normalized features + image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True) + text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True) + + # cosine similarity as logits and set it on the correct device + logit_scale = self.logit_scale.exp().to(image_embeds.device) + + logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = owlvit_loss(logits_per_text) + + text_embeds = text_embeds_norm + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return OwlViTOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class OwlViTBoxPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig, out_dim: int = 4): + super().__init__() + + width = config.vision_config.hidden_size + self.dense0 = nn.Linear(width, width) + self.dense1 = nn.Linear(width, width) + self.gelu = nn.GELU() + self.dense2 = nn.Linear(width, out_dim) + + def forward(self, image_features: torch.Tensor) -> torch.FloatTensor: + output = self.dense0(image_features) + output = self.gelu(output) + output = self.dense1(output) + output = self.gelu(output) + output = self.dense2(output) + return output + + +class OwlViTClassPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig): + super().__init__() + + out_dim = config.text_config.hidden_size + self.query_dim = config.vision_config.hidden_size + + self.dense0 = nn.Linear(self.query_dim, out_dim) + self.logit_shift = nn.Linear(self.query_dim, 1) + self.logit_scale = nn.Linear(self.query_dim, 1) + self.elu = nn.ELU() + + def forward( + self, + image_embeds: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor], + query_mask: Optional[torch.Tensor], + ) -> Tuple[torch.FloatTensor]: + image_class_embeds = self.dense0(image_embeds) + if query_embeds is None: + device = image_class_embeds.device + batch_size, num_patches = image_class_embeds.shape[:2] + pred_logits = 
torch.zeros((batch_size, num_patches, self.query_dim)).to(device) + return (pred_logits, image_class_embeds) + + # Normalize image and text features + image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6) + query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6) + + # Get class predictions + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + + # Apply a learnable shift and scale to logits + logit_shift = self.logit_shift(image_embeds) + logit_scale = self.logit_scale(image_embeds) + logit_scale = self.elu(logit_scale) + 1 + pred_logits = (pred_logits + logit_shift) * logit_scale + + if query_mask is not None: + if query_mask.ndim > 1: + query_mask = torch.unsqueeze(query_mask, dim=-2) + + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) + pred_logits = pred_logits.to(torch.float32) + + return (pred_logits, image_class_embeds) + + +class OwlViTForObjectDetection(OwlViTPreTrainedModel): + config_class = OwlViTConfig + + def __init__(self, config: OwlViTConfig): + super().__init__(config) + + self.owlvit = OwlViTModel(config) + self.class_head = OwlViTClassPredictionHead(config) + self.box_head = OwlViTBoxPredictionHead(config) + + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps) + self.sigmoid = nn.Sigmoid() + self.config = config + self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size + self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width) + + @staticmethod + def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor: + # Create grid coordinates using torch + x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32) + y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32) + xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy") + + # Stack the coordinates and divide by their respective patch counts + box_coordinates = torch.stack((xx, yy), dim=-1) + box_coordinates[..., 0] /= num_patches_width + box_coordinates[..., 1] /= num_patches_height + + # Flatten (h, w, 2) -> (h*w, 2) + box_coordinates = box_coordinates.view(-1, 2) + + return box_coordinates + + @lru_cache(maxsize=2) + def compute_box_bias( + self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None + ) -> torch.Tensor: + if feature_map is not None: + raise ValueError("feature_map has been deprecated as an input. 
Please pass in num_patches instead") + # The box center is biased to its position on the feature grid + box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width) + box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) + + # Unnormalize xy + box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) + + # The box size is biased to the patch size + box_size = torch.full_like(box_coord_bias, 1.0) + box_size[..., 0] /= num_patches_width + box_size[..., 1] /= num_patches_height + box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) + + # Compute box bias + box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1) + return box_bias + + def box_predictor( + self, + image_feats: torch.FloatTensor, + feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + """ + Args: + image_feats: + Features extracted from the image, returned by the `image_text_embedder` method. + feature_map: + A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + interpolate_pos_encoding: + Whether to interpolate the pre-trained position encodings. + Returns: + pred_boxes: + List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. + """ + # Bounding box detection head [batch_size, num_boxes, 4]. + pred_boxes = self.box_head(image_feats) + + # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction + if interpolate_pos_encoding: + _, num_patches_height, num_patches_width, _ = feature_map.shape + box_bias = self.compute_box_bias(num_patches_height, num_patches_width) + else: + box_bias = self.box_bias + + box_bias = box_bias.to(feature_map.device) + pred_boxes += box_bias + pred_boxes = self.sigmoid(pred_boxes) + return pred_boxes + + def class_predictor( + self, + image_feats: torch.FloatTensor, + query_embeds: Optional[torch.FloatTensor] = None, + query_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + image_feats: + Features extracted from the `image_text_embedder`. + query_embeds: + Text query embeddings. + query_mask: + Must be provided with query_embeddings. A mask indicating which query embeddings are valid. 
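+        Returns:
+            pred_logits:
+                Class logits of shape `(batch_size, num_patches, num_queries)`, as produced by
+                [`OwlViTClassPredictionHead`].
+            image_class_embeds:
+                Per-patch class embeddings used to compute the logits above.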
+ """ + (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask) + + return (pred_logits, image_class_embeds) + + def image_text_embedder( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.FloatTensor]: + # Encode text and image + outputs = self.owlvit( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Get image embeddings + last_hidden_state = outputs.vision_model_output[0] + image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + text_embeds = outputs[-4] + + return (text_embeds, image_embeds, outputs) + + def image_embedder( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.FloatTensor]: + # Get OwlViTModel vision embeddings (same as CLIP) + vision_outputs = self.owlvit.vision_model( + pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True + ) + + if interpolate_pos_encoding: + _, _, height, width = pixel_values.shape + num_patches_height = height // self.config.vision_config.patch_size + num_patches_width = width // self.config.vision_config.patch_size + else: + num_patches_height = self.num_patches_height + num_patches_width = self.num_patches_width + + # Apply post_layernorm to last_hidden_state, return non-projected output + last_hidden_state = vision_outputs[0] + image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state) + + # Resize class token + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches_height, num_patches_width, hidden_size] + new_size = ( + image_embeds.shape[0], + num_patches_height, + num_patches_width, + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + + return (image_embeds, vision_outputs) + + def embed_image_query( + self, + query_image_features: torch.FloatTensor, + query_feature_map: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + _, class_embeds = self.class_predictor(query_image_features) + pred_boxes = 
self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding) + pred_boxes_as_corners = center_to_corners_format(pred_boxes) + + # Loop over query images + best_class_embeds = [] + best_box_indices = [] + pred_boxes_device = pred_boxes_as_corners.device + + for i in range(query_image_features.shape[0]): + each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device) + each_query_pred_boxes = pred_boxes_as_corners[i] + ious, _ = box_iou(each_query_box, each_query_pred_boxes) + + # If there are no overlapping boxes, fall back to generalized IoU + if torch.all(ious[0] == 0.0): + ious = generalized_box_iou(each_query_box, each_query_pred_boxes) + + # Use an adaptive threshold to include all boxes within 80% of the best IoU + iou_threshold = torch.max(ious) * 0.8 + + selected_inds = (ious[0] >= iou_threshold).nonzero() + if selected_inds.numel(): + selected_embeddings = class_embeds[i][selected_inds.squeeze(1)] + mean_embeds = torch.mean(class_embeds[i], axis=0) + mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings) + best_box_ind = selected_inds[torch.argmin(mean_sim)] + best_class_embeds.append(class_embeds[i][best_box_ind]) + best_box_indices.append(best_box_ind) + + if best_class_embeds: + query_embeds = torch.stack(best_class_embeds) + box_indices = torch.stack(best_box_indices) + else: + query_embeds, box_indices = None, None + + return query_embeds, box_indices, pred_boxes + + @add_start_docstrings_to_model_forward(OWLVIT_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OwlViTImageGuidedObjectDetectionOutput, config_class=OwlViTConfig) + def image_guided_detection( + self, + pixel_values: torch.FloatTensor, + query_pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> OwlViTImageGuidedObjectDetectionOutput: + r""" + Returns: + + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + >>> from transformers import AutoProcessor, OwlViTForObjectDetection + + >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg" + >>> query_image = Image.open(requests.get(query_url, stream=True).raw) + >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model.image_guided_detection(**inputs) + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.Tensor([image.size[::-1]]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_image_guided_detection( + ... outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes + ... ) + >>> i = 0 # Retrieve predictions for the first image + >>> boxes, scores = results[i]["boxes"], results[i]["scores"] + >>> for box, score in zip(boxes, scores): + ... box = [round(i, 2) for i in box.tolist()] + ... 
print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}") + Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39] + Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Compute feature maps for the input and query images + query_feature_map = self.image_embedder( + pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + )[0] + feature_map, vision_outputs = self.image_embedder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape + query_image_feats = torch.reshape( + query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim) + ) + # Get top class embedding and best box index for each query image in batch + query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query( + query_image_feats, query_feature_map, interpolate_pos_encoding + ) + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds) + + # Predict object boxes + target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding) + + if not return_dict: + output = ( + feature_map, + query_feature_map, + target_pred_boxes, + query_pred_boxes, + pred_logits, + class_embeds, + vision_outputs.to_tuple(), + ) + output = tuple(x for x in output if x is not None) + return output + + return OwlViTImageGuidedObjectDetectionOutput( + image_embeds=feature_map, + query_image_embeds=query_feature_map, + target_pred_boxes=target_pred_boxes, + query_pred_boxes=query_pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_output=None, + vision_model_output=vision_outputs, + ) + + @add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig) + def forward( + self, + input_ids: torch.Tensor, + pixel_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> OwlViTObjectDetectionOutput: + r""" + Returns: + + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + + >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection + + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text_labels = [["a photo of 
a cat", "a photo of a dog"]] + >>> inputs = processor(text=text_labels, images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] + >>> target_sizes = torch.tensor([(image.height, image.width)]) + >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> results = processor.post_process_grounded_object_detection( + ... outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels + ... ) + >>> # Retrieve predictions for the first image for the corresponding text queries + >>> result = results[0] + >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"] + >>> for box, score, text_label in zip(boxes, scores, text_labels): + ... box = [round(i, 2) for i in box.tolist()] + ... print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}") + Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29] + Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Embed images and text queries + query_embeds, feature_map, outputs = self.image_text_embedder( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + # Text and vision model outputs + text_outputs = outputs.text_model_output + vision_outputs = outputs.vision_model_output + + batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)) + + # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] + max_text_queries = input_ids.shape[0] // batch_size + query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1]) + + # If first token is 0, then this is a padded query [batch_size, num_queries]. 
+        input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
+        query_mask = input_ids[..., 0] > 0
+
+        # Predict object classes [batch_size, num_patches, num_queries+1]
+        (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
+
+        # Predict object boxes
+        pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
+
+        if not return_dict:
+            output = (
+                pred_logits,
+                pred_boxes,
+                query_embeds,
+                feature_map,
+                class_embeds,
+                text_outputs.to_tuple(),
+                vision_outputs.to_tuple(),
+            )
+            output = tuple(x for x in output if x is not None)
+            return output
+
+        return OwlViTObjectDetectionOutput(
+            image_embeds=feature_map,
+            text_embeds=query_embeds,
+            pred_boxes=pred_boxes,
+            logits=pred_logits,
+            class_embeds=class_embeds,
+            text_model_output=text_outputs,
+            vision_model_output=vision_outputs,
+        )
+
+
+__all__ = ["OwlViTModel", "OwlViTPreTrainedModel", "OwlViTTextModel", "OwlViTVisionModel", "OwlViTForObjectDetection"]
diff --git a/docs/transformers/src/transformers/models/owlvit/processing_owlvit.py b/docs/transformers/src/transformers/models/owlvit/processing_owlvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..859e28bfcc8b335a61f7dd78aa486b368f2f7f8e
--- /dev/null
+++ b/docs/transformers/src/transformers/models/owlvit/processing_owlvit.py
@@ -0,0 +1,352 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for OWL-ViT
+"""
+
+import warnings
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    Unpack,
+    _validate_images_text_input_order,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available
+
+
+if TYPE_CHECKING:
+    from .modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput, OwlViTObjectDetectionOutput
+
+
+class OwlViTImagesKwargs(ImagesKwargs, total=False):
+    query_images: Optional[ImageInput]
+
+
+class OwlViTProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: OwlViTImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": "max_length",
+        },
+        "images_kwargs": {},
+        "common_kwargs": {
+            "return_tensors": "np",
+        },
+    }
+
+
+class OwlViTProcessor(ProcessorMixin):
+    r"""
+    Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`]
+    into a single processor that inherits both the image processor and tokenizer functionalities. See the
+    [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`OwlViTImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "OwlViTImageProcessor"
+    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+    # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
+    optional_call_args = ["query_images"]
+
+    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+        feature_extractor = None
+        if "feature_extractor" in kwargs:
+            warnings.warn(
+                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+                " instead.",
+                FutureWarning,
+            )
+            feature_extractor = kwargs.pop("feature_extractor")
+
+        image_processor = image_processor if image_processor is not None else feature_extractor
+        if image_processor is None:
+            raise ValueError("You need to specify an `image_processor`.")
+        if tokenizer is None:
+            raise ValueError("You need to specify a `tokenizer`.")
+
+        super().__init__(image_processor, tokenizer)
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        # The following is to capture `query_images` argument that may be passed as a positional argument.
+        # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
+        # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
+        # This behavior is only needed for backward compatibility and will be removed in future versions.
+        #
+        *args,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[OwlViTProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several text(s) and image(s) for the model. This method forwards the `text` and
+        `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+            `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The query image to be prepared, one query image is expected per target image to be queried. Each image
+                can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
+                should be of shape (C, H, W), where C is a number of channels, H and W are image height and width.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework.
Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + OwlViTProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + query_images = output_kwargs["images_kwargs"].pop("query_images", None) + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] + + if text is None and query_images is None and images is None: + raise ValueError( + "You have to specify at least one text or query image or image. All three cannot be none." + ) + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + data = {} + if text is not None: + if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): + encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])] + + elif isinstance(text, List) and isinstance(text[0], List): + encodings = [] + + # Maximum number of queries across batch + max_num_queries = max([len(text_single) for text_single in text]) + + # Pad all batch samples to max number of text queries + for text_single in text: + if len(text_single) != max_num_queries: + text_single = text_single + [" "] * (max_num_queries - len(text_single)) + + encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"]) + encodings.append(encoding) + else: + raise TypeError("Input text should be a string, a list of strings or a nested list of strings") + + if return_tensors == "np": + input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "jax" and is_flax_available(): + import jax.numpy as jnp + + input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "pt" and is_torch_available(): + import torch + + input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) + attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) + + elif return_tensors == "tf" and is_tf_available(): + import tensorflow as tf + + input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) + + else: + raise ValueError("Target return tensor type could not be returned") + + data["input_ids"] = input_ids + data["attention_mask"] = attention_mask + + if query_images is not None: + query_pixel_values = self.image_processor(query_images, 
**output_kwargs["images_kwargs"]).pixel_values + # Query images always override the text prompt + data = {"query_pixel_values": query_pixel_values} + + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data["pixel_values"] = image_features.pixel_values + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OwlViTImageProcessor.post_process`]. Please refer to the docstring + of this method for more information. + """ + return self.image_processor.post_process(*args, **kwargs) + + def post_process_object_detection(self, *args, **kwargs): + """ + This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer + to the docstring of this method for more information. + """ + warnings.warn( + "`post_process_object_detection` method is deprecated for OwlVitProcessor and will be removed in v5. " + "Use `post_process_grounded_object_detection` instead.", + FutureWarning, + ) + return self.image_processor.post_process_object_detection(*args, **kwargs) + + def post_process_grounded_object_detection( + self, + outputs: "OwlViTObjectDetectionOutput", + threshold: float = 0.1, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + text_labels: Optional[List[List[str]]] = None, + ): + """ + Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.1): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + text_labels (`List[List[str]]`, *optional*): + List of lists of text labels for each image in the batch. If unset, "text_labels" in output will be + set to `None`. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "labels": Indexes of the classes predicted by the model on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + - "text_labels": The text labels for each predicted bounding box on the image. 
+ """ + output = self.image_processor.post_process_object_detection( + outputs=outputs, threshold=threshold, target_sizes=target_sizes + ) + + if text_labels is not None and len(text_labels) != len(output): + raise ValueError("Make sure that you pass in as many lists of text labels as images") + + # adding text labels to the output + if text_labels is not None: + for image_output, image_text_labels in zip(output, text_labels): + object_text_labels = [image_text_labels[i] for i in image_output["labels"]] + image_output["text_labels"] = object_text_labels + else: + for image_output in output: + image_output["text_labels"] = None + + return output + + def post_process_image_guided_detection( + self, + outputs: "OwlViTImageGuidedObjectDetectionOutput", + threshold: float = 0.0, + nms_threshold: float = 0.3, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + ): + """ + Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO + api. + + Args: + outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.0): + Minimum confidence threshold to use to filter out predicted boxes. + nms_threshold (`float`, *optional*, defaults to 0.3): + IoU threshold for non-maximum suppression of overlapping boxes. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to + None, predictions will not be unnormalized. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the following keys: + - "scores": The confidence scores for each predicted box on the image. + - "boxes": Image bounding boxes in (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format. + - "labels": Set to `None`. + """ + return self.image_processor.post_process_image_guided_detection( + outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes + ) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + @property + def feature_extractor(self): + warnings.warn( + "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["OwlViTProcessor"] diff --git a/docs/transformers/src/transformers/models/paligemma/__init__.py b/docs/transformers/src/transformers/models/paligemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9048afe6adbdc0ad36007e02f60e899cae677c55 --- /dev/null +++ b/docs/transformers/src/transformers/models/paligemma/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_paligemma import * + from .modeling_paligemma import * + from .processing_paligemma import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/paligemma/configuration_paligemma.py b/docs/transformers/src/transformers/models/paligemma/configuration_paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..4551b85bcd54811b0bc1ee44cfb675fb27553bfa --- /dev/null +++ b/docs/transformers/src/transformers/models/paligemma/configuration_paligemma.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PaliGemmamodel configuration""" + +import warnings + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class PaliGemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate an + PaliGemmamodel according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`PaliGemmaVisionConfig`, *optional*): + Custom vision config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 256000): + The image token index to encode the image prompt. + vocab_size (`int`, *optional*, defaults to 257152): + Vocabulary size of the PaliGemmamodel. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~PaliGemmaForConditionalGeneration`] + projection_dim (`int`, *optional*, defaults to 2048): + Dimension of the multimodal projection space. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden layer of the Language model. + + Example: + + ```python + >>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a PaliGemma config + >>> text_config = GemmaConfig() + + >>> # Initializing a PaliGemma paligemma-3b-224 style configuration + >>> configuration = PaliGemmaConfig(vision_config, text_config) + + >>> # Initializing a model from the paligemma-3b-224 style configuration + >>> model = PaliGemmaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "paligemma" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=256000, + vocab_size=257152, + projection_dim=2048, + hidden_size=2048, + **kwargs, + ): + self._ignore_index = ignore_index + self.image_token_index = image_token_index + self._vocab_size = vocab_size + self.projection_dim = projection_dim + self.hidden_size = hidden_size + self.vision_config = vision_config + self.is_encoder_decoder = False + + if isinstance(self.vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model" + ) + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["siglip_vision_model"]( + intermediate_size=4096, + hidden_size=1152, + patch_size=14, + image_size=224, + num_hidden_layers=27, + num_attention_heads=16, + vocab_size=257152, + vision_use_head=False, + ) + + self.text_config = text_config + if isinstance(self.text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma" + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + self.text_config = CONFIG_MAPPING["gemma"]( + hidden_size=2048, + num_hidden_layers=18, + intermediate_size=16384, + num_attention_heads=8, + num_key_value_heads=1, + is_encoder_decoder=False, + vocab_size=vocab_size, + ) + self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2 + self.vision_config.projection_dim = projection_dim + super().__init__(**kwargs) + + @property + def ignore_index(self): + warnings.warn( + "The `ignore_index` attribute is deprecated and will be removed in v4.47.", + FutureWarning, + ) + return self._ignore_index + + @ignore_index.setter + def ignore_index(self, value): + self._ignore_index = value + + def to_dict(self): + output = super().to_dict() + output.pop("_ignore_index", None) + return output + + +__all__ = ["PaliGemmaConfig"] diff --git a/docs/transformers/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py b/docs/transformers/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py new file mode 100644 
index 0000000000000000000000000000000000000000..3334e6f28fc32b4dd1d54dff35f4091e24a01602 --- /dev/null +++ b/docs/transformers/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py @@ -0,0 +1,415 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert PaliGemma2 checkpoints from the original repository.""" + +import argparse +import collections + +import jax.numpy as jnp +import ml_dtypes +import numpy as np +import torch + +from transformers import ( + AutoTokenizer, + Gemma2Config, + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaProcessor, + SiglipImageProcessor, +) +from transformers.tokenization_utils_base import AddedToken +from transformers.utils import logging + + +device = "cpu" + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# TODO add sequence length variations here + +PALIGEMMA2_VARIANTS = ["2b-224", "2b-448", "2b-896", "9b-224", "9b-448", "9b-896", "27b-224", "27b-448", "27b-896"] +VARIANT_CONFIGS = { + "2b": { + "num_positions": 256, + "hidden_size": 2304, + "num_hidden_layers": 26, + "intermediate_size": 9216, + "num_key_value_heads": 4, + "num_attention_heads": 8, + "head_dim": 256, + "query_pre_attn_scalar": 256, + }, + "9b": { + "num_positions": 1024, + "hidden_size": 3584, + "num_hidden_layers": 42, + "intermediate_size": 14336, + "num_key_value_heads": 8, + "num_attention_heads": 16, + "head_dim": 256, + "query_pre_attn_scalar": 256, + }, + "27b": { + "num_positions": 4096, + "hidden_size": 4608, + "num_hidden_layers": 46, + "intermediate_size": 36864, + "num_key_value_heads": 16, + "num_attention_heads": 32, + "head_dim": 128, + "query_pre_attn_scalar": 4608 // 32, # scaling is different for the 28b + }, +} + +DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + + +def get_paligemma2_config(variant: str, precision: str): + config = { + "image_token_id": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + base_variant = variant.split("-")[0] + + if variant in PALIGEMMA2_VARIANTS: + image_size = int(variant.split("-")[1]) + variant_config = VARIANT_CONFIGS[base_variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + config["projection_dim"] = variant_config["hidden_size"] + config["image_token_id"] = 257152 + config["num_hidden_layers"] = variant_config["num_hidden_layers"] # For generate + text_config = Gemma2Config.from_pretrained("google/gemma-2-2b-it").to_dict() + sup_text_config = { + "model_type": "gemma2", + "vocab_size": 257152, + "num_hidden_layers": variant_config["num_hidden_layers"], + "num_key_value_heads": variant_config["num_key_value_heads"], + "head_dim": variant_config["head_dim"], + "torch_dtype": precision, + "hidden_size": variant_config["hidden_size"], + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": variant_config["num_attention_heads"], + "intermediate_size": variant_config["intermediate_size"], + 
"is_encoder_decoder": False, + "query_pre_attn_scalar": variant_config["query_pre_attn_scalar"], + } + text_config.update(sup_text_config) + + vision_config = { + "num_positions": variant_config["num_positions"], # not useful, to remove + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projection_dim": variant_config["hidden_size"], + "hidden_act": "gelu_pytorch_tanh", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + else: + raise ValueError(f"Identifier {variant} not supported. Available: {PALIGEMMA2_VARIANTS}") + return final_config + + +def slice_state_dict(state_dict, config): + # fmt: off + # patch embeddings + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose( + 3, 2, 0, 1 + ) + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias") + # positional embeddings + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. + encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale") + encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias") + encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale") + encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias") + + encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel") + encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias") + encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel") + encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias") + + encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel") + encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias") + encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel") + encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias") + encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel") + encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias") + encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel") + encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = 
encoderblock_layernorm0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose() + state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias") + + # multimodal projector + + state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose() + state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias") + + # text decoder (gemma) + + embedding_vector = state_dict.pop("llm/embedder/input_embedding") + state_dict["language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 26 layers in gemma2-2b. 
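+    # Shape bookkeeping for the 2b variant (values from the inline comments below; larger variants differ
+    # only in head counts and hidden size): each per-layer q_einsum kernel is (num_heads, hidden_size,
+    # head_dim) = (8, 2304, 256). PyTorch nn.Linear stores weights as (out_features, in_features), so the
+    # kernel is transposed to (8, 256, 2304) and reshaped to (8 * 256, 2304) = (2048, 2304) for q_proj.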
+ + llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w") + # (26, 2, 4, 2304, 256) for 2b-224, 4 kv heads and 26 layers + llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w") + llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w") + llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum") + llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear") + # TODO verify correctness of layer norm loading + llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale") + llm_pre_feedforward_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale") + + llm_post_attention_layernorm = state_dict.pop("llm/layers/post_attention_norm/scale") + llm_post_feedforward_layernorm = state_dict.pop("llm/layers/post_ffw_norm/scale") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + # q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + """ + q shape (8, 2304, 256) + k shape (4, 2304, 256) + v shape (4, 2304, 256) + o shape (8, 256, 2304) + + """ + q_transpose = (0, 2, 1) + k_transpose = (0, 2, 1) + v_transpose = (0, 2, 1) + o_transpose = (2, 0, 1) + + q_weight_matrices = llm_attention_q_einsum[i].transpose(*q_transpose) + q_proj_weight_reshaped = q_weight_matrices + q_proj_weight_reshaped = q_proj_weight_reshaped.reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + # Shape: (4, 2304, 256) + k_weight_matrices = llm_attention_kv_einsum[i, 0].transpose(*k_transpose) + k_proj_weight_reshaped = k_weight_matrices.reshape( + config.text_config.num_key_value_heads * config.text_config.head_dim, + config.text_config.hidden_size + ) + state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1].shape = (num_key_value_heads, hidden_size, head_dim) + v_weight_matrices = llm_attention_kv_einsum[i, 1].transpose(*v_transpose) # Shape: (4, 2304, 256) + v_proj_weight_reshaped = v_weight_matrices.reshape( + config.text_config.num_key_value_heads * config.text_config.head_dim, + config.text_config.hidden_size + ) + state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. 
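+        # The (2, 0, 1) transpose below moves hidden_size to the front, so the (8, 256, 2304) einsum kernel
+        # becomes (2304, 8, 256) and flattens to (2304, 2048): o_proj maps concatenated heads back to hidden size.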
+ + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2304) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(*o_transpose).reshape(config.text_config.hidden_size, config.text_config.num_attention_heads * config.text_config.head_dim) + state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + state_dict[f"language_model.model.layers.{i}.pre_feedforward_layernorm.weight"] = llm_pre_feedforward_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_feedforward_layernorm.weight"] = llm_post_feedforward_layernorm[i] + state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale") + state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied. + [k for k in state_dict.keys() if not k.startswith('vision') and not k.startswith('language')] + # fmt: on + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + try: + if value.dtype == jnp.bfloat16: + value = jnp.array(value).astype(jnp.float32) + value = np.array(value) + state_dict[key] = torch.from_numpy(value).to(torch.bfloat16) + else: + state_dict[key] = torch.from_numpy(value) + except Exception as initial_exception: + raise ValueError(f"Conversion failed from jax weights with {initial_exception}. Check your inputs.") + return state_dict + + +def flatten_nested_dict(params, parent_key="", sep="/", precision: int = "float32"): + items = [] + + for k, v in params.items(): + k = k.removeprefix("params/") + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep, precision=precision).items()) + else: + if precision == "bfloat16": + try: + v = v.view(ml_dtypes.bfloat16) + except Exception as initial_exception: + raise ValueError(f"Conversion failed from bfloat16 with {initial_exception}, check your inputs.") + items.append((new_key, v)) + return dict(items) + + +@torch.no_grad() +def convert_paligemma2_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + variant: str, + precision: str, + do_convert_weights=False, +): + """ + Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed. + """ + config = get_paligemma2_config(variant, precision=precision) + if do_convert_weights: + tokenizer_id = "google/paligemma-3b-pt-224" # same tokenizer as paligemma 1 + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + image_token = AddedToken("", normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + + # tokenizer.padding_side = 'right' # uncomment for testing purposes only. 
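+        # A hypothetical end-to-end invocation of this script (paths are placeholders; the flags correspond
+        # to the argparse definitions at the bottom of this file):
+        #   python convert_paligemma2_weights_to_hf.py \
+        #       --checkpoint_path ./paligemma2-2b-224.npz \
+        #       --pytorch_dump_folder_path ./paligemma2-2b-224-hf \
+        #       --variant 2b-224 --precision bfloat16 --do_convert_weights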
+ + image_processor = SiglipImageProcessor.from_pretrained("google/paligemma-3b-pt-224") + image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size} + image_processor.image_seq_length = config.vision_config.num_image_tokens + + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + data = jnp.load(checkpoint_path) + state_dict = flatten_nested_dict(data, precision=precision) + del data + state_dict_transformers = slice_state_dict(state_dict, config) + del state_dict + del config.hidden_size # this key is unused + model = PaliGemmaForConditionalGeneration(config).to(device).eval() + model.load_state_dict(state_dict_transformers) + del state_dict_transformers + model.config.text_config._attn_implementation = "sdpa" + + # model expansion to get random embeds of image tokens + pad_shape = 64 # for performance reasons + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack( + tuple( + (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0])) + ), + dim=0, + ) + model.language_model.lm_head.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))), + dim=0, + ) + # convert to needed precision + + model.to(DTYPES[precision]) + model.save_pretrained(pytorch_dump_folder_path, safe_serialization=True) + processor.save_pretrained(pytorch_dump_folder_path) + + else: + processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path, do_rescale=False) + model = ( + PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa") + .to(device) + .eval() + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + required=True, + type=str, + help="Path to the .npz checkpoint", + ) + + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=str, + help="Path to the output directory where model and processor will be saved.", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + + parser.add_argument( + "--variant", + default="2b-224", + choices=PALIGEMMA2_VARIANTS, + type=str, + help="String identifier of the paligemma2 variant to convert.", + ) + + parser.add_argument( + "--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights." 
+ ) + + args = parser.parse_args() + convert_paligemma2_checkpoint( + checkpoint_path=args.checkpoint_path, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + variant=args.variant, + precision=args.precision, + do_convert_weights=args.do_convert_weights, + ) diff --git a/docs/transformers/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py b/docs/transformers/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..054872a799aca08befa064fecaa1034048af023b --- /dev/null +++ b/docs/transformers/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py @@ -0,0 +1,347 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert PaliGemma checkpoints from the original repository.""" + +import argparse +import collections + +import torch +from numpy import load + +from transformers import ( + AutoTokenizer, + GemmaTokenizer, + GemmaTokenizerFast, + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaProcessor, + SiglipImageProcessor, +) +from transformers.tokenization_utils_base import AddedToken +from transformers.utils import logging + + +device = "cuda" # "cpu" + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# TODO add sequence length variations here + +PALIGEMMA_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"] + + +def get_paligemma_config(variant: str, precision: str): + config = { + "image_token_id": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} + + if variant in PALIGEMMA_VARIANTS: + image_size = image_sizes[variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + + config["image_token_id"] = 257152 if variant != "2b-test" else 256000 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 2048, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 16384, + "is_encoder_decoder": False, + } + vision_config = { + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + else: + raise ValueError(f"Identifier {variant} not supported. 
Available: {PALIGEMMA_VARIANTS}") + return final_config + + +def slice_state_dict(state_dict, config): + # fmt: off + # patch embeddings + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose( + 3, 2, 0, 1 + ) + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias") + # positional embeddings + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. + encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale") + encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias") + encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale") + encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias") + + encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel") + encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias") + encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel") + encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias") + + encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel") + encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias") + encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel") + encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias") + encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel") + encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias") + encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel") + encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + 
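+        # Each flax attention kernel here is stored as (hidden_size, num_heads, head_dim); because
+        # num_heads * head_dim equals hidden_size in this SigLIP encoder (16 * 72 = 1152), the
+        # reshape(-1, hidden_size) calls below flatten it to (in_features, out_features), and the final
+        # transpose gives the (out_features, in_features) layout that PyTorch nn.Linear expects.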
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose() + state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias") + + # multimodal projector + + state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose() + state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias") + + # text decoder (gemma) + + embedding_vector = state_dict.pop("llm/embedder/input_embedding") + state_dict["language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
+ + llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w") + llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w") + llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w") + + llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum") + llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale") + llm_post_attention_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. + + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale") + state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied. + + # fmt: on + for key, value in state_dict.items(): + state_dict[key] = torch.from_numpy(value) + return state_dict + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + k = k.removeprefix("params/") + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@torch.no_grad() +def convert_paligemma_checkpoint( + checkpoint_path, + tokenizer_model_file, + pytorch_dump_folder_path, + variant: str, + precision: str, + do_convert_weights=False, +): + """ + Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed. 
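+
+    Example (an illustrative invocation; the paths below are placeholders):
+
+    ```python
+    convert_paligemma_checkpoint(
+        checkpoint_path="./paligemma-3b-224.npz",
+        tokenizer_model_file="./tokenizer.model",
+        pytorch_dump_folder_path="./paligemma-3b-224px-hf",
+        variant="3b-224px",
+        precision="float32",
+        do_convert_weights=True,
+    )
+    ```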
+ """ + config = get_paligemma_config(variant, precision=precision) + if do_convert_weights: + if variant == "2b-test": + # for the test model, the vocabulary was smaller + tokenizer_id = "google/gemma-2b" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + else: + tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast + tokenizer = tokenizer_class(tokenizer_model_file) + image_token = AddedToken("", normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + + # tokenizer.padding_side = 'right' # uncomment for testing purposes only. + + image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384") + image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size} + image_processor.image_seq_length = config.vision_config.num_image_tokens + + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + data = load(checkpoint_path) + state_dict = flatten_nested_dict(data) + del data + state_dict_transformers = slice_state_dict(state_dict, config) + del state_dict + + model = PaliGemmaForConditionalGeneration(config).to(device).eval() + model.load_state_dict(state_dict_transformers) + del state_dict_transformers + + else: + processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path) + model = ( + PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa") + .to(device) + .eval() + ) + model.config.text_config._attn_implementation = "sdpa" + + # model expansion to get random embeds of image tokens + pad_shape = 64 # for performance reasons + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0]))), + dim=0, + ) + model.language_model.lm_head.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))), + dim=0, + ) + + model.save_pretrained(pytorch_dump_folder_path, max_shard_size="2GB", safe_serialization=True) + processor.save_pretrained(pytorch_dump_folder_path) + + +# + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + required=True, + type=str, + help="Path to the .npz checkpoint", + ) + + parser.add_argument( + "--tokenizer_model_file", + required=True, + type=str, + help="Path to the sentencepiece tokenizer.model file", + ) + + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=str, + help="Path to the output directory where model and processor will be saved.", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + + parser.add_argument( + "--variant", + default="2b-test", + 
choices=PALIGEMMA_VARIANTS, + type=str, + help="String identifier of the paligemma variant to convert.", + ) + + parser.add_argument( + "--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights." + ) + + args = parser.parse_args() + convert_paligemma_checkpoint( + checkpoint_path=args.checkpoint_path, + tokenizer_model_file=args.tokenizer_model_file, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + variant=args.variant, + precision=args.precision, + do_convert_weights=args.do_convert_weights, + ) diff --git a/docs/transformers/src/transformers/models/paligemma/modeling_paligemma.py b/docs/transformers/src/transformers/models/paligemma/modeling_paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..ade69444f462814c001f862dc280ccaa55be41bf --- /dev/null +++ b/docs/transformers/src/transformers/models/paligemma/modeling_paligemma.py @@ -0,0 +1,622 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PaliGemma model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache, HybridCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_paligemma import PaliGemmaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PaliGemmaConfig" + + +# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position +# But Paligemma has no causal mask on prefix +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, + is_training: bool = False, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`int`): + Batch size. + is_training (`bool`): + Whether the model is in training or inference mode. The condition is checked by the presence/absence of `token_type_ids`/`labels`. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + # we are training thus we need to create a full mask on the image + prefix but causal on suffix + if is_training: + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + return causal_mask + + +@dataclass +class PaliGemmaCausalLMOutputWithPast(ModelOutput): + """ + Base class for PaliGemma causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)` + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class PaliGemmaMultiModalProjector(nn.Module): + def __init__(self, config: PaliGemmaConfig): + super().__init__() + self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True) + + def forward(self, image_features): + hidden_states = self.linear(image_features) + + return hidden_states + + +PALIGEMMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage + and behavior. + + Parameters: + config ([`PaliGemmaConfig`] or [`PaliGemmaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare PaliGemma Model outputting raw hidden-states without any specific head on top.", + PALIGEMMA_START_DOCSTRING, +) +class PaliGemmaPreTrainedModel(PreTrainedModel): + config_class = PaliGemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PaliGemmaMultiModalProjector"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + # important: this ported version of PaliGemma isn't meant for training from scratch - only + # inference and fine-tuning + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + +PALIGEMMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details.
+ + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`PaliGemmaProcessor`] uses + [`SiglipImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """The PALIGEMMA model which consists of a vision backbone and a language model.""", + PALIGEMMA_START_DOCSTRING, +) +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): + def __init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModelForCausalLM.from_config(config=config.text_config) + + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma + def get_decoder(self): + return self.language_model.get_decoder() + + def _update_causal_mask( + self, + attention_mask, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + is_training: Optional[bool] = None, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + is_training = is_training if is_training is not None else self.training + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(self.dtype).min + if input_tensor is None: + input_tensor = attention_mask + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and 
attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device + ) + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + if token_type_ids is None: + raise ValueError("Token type ids must be provided during training") + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features / (self.config.text_config.hidden_size**0.5) + return image_features + + @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" + >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id with PAD if the image token is OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = 
(special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs: CausalLMOutputWithPast = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs[0] + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + output = PaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + return output if return_dict else output.to_tuple() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + +__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/paligemma/processing_paligemma.py b/docs/transformers/src/transformers/models/paligemma/processing_paligemma.py new file mode 100644 index 0000000000000000000000000000000000000000..5048f0c3eef860509af06eb4184013c9faa20e7f --- /dev/null +++ b/docs/transformers/src/transformers/models/paligemma/processing_paligemma.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for PaliGemma. +""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images +from ...processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + TextKwargs, + Unpack, + _validate_images_text_input_order, +) +from ...tokenization_utils_base import ( + AddedToken, + PreTokenizedInput, + TextInput, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) + +IMAGE_TOKEN = "<image>" +EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)] + + +class PaliGemmaTextKwargs(TextKwargs): + suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] + + +class PaliGemmaImagesKwargs(ImagesKwargs): + do_convert_rgb: Optional[bool] + + +class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: PaliGemmaTextKwargs + images_kwargs: PaliGemmaImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "data_format": "channels_first", + }, + } + + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +def _is_str_or_image(elem): + return isinstance(elem, (str)) or is_image_or_image_url(elem) + + +def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images): + """ + Builds a string from the input prompt and image tokens. + For example, for the call: + build_string_from_input( + prompt="Initial str", + bos_token="<s>", + image_seq_len=3, + image_token="<im>", + num_images=1, + ) + The output will be: + "<im><im><im><s>Initial str" + Args: + prompt (`List[Union[str, ImageInput]]`): The input prompt. + bos_token (`str`): The beginning of sentence token. + image_seq_len (`int`): The length of the image sequence. + image_token (`str`): The image token. + num_images (`int`): Number of images in the prompt. + """ + return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n" + + +class PaliGemmaProcessor(ProcessorMixin): + r""" + Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. + + [`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`GemmaTokenizerFast`]. See the + [`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`GemmaTokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string.
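+
+    Example (a minimal sketch; the checkpoint id below is illustrative):
+
+    ```python
+    >>> from transformers import PaliGemmaProcessor
+
+    >>> processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
+    >>> # `image` is a `PIL.Image.Image`
+    >>> inputs = processor(images=image, text="answer en Where is the cow standing?", return_tensors="pt")
+    ```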
+    """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast") + tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + **kwargs, + ): + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + if not hasattr(image_processor, "image_seq_length"): + raise ValueError("Image processor is missing an `image_seq_length` attribute.") + + self.image_seq_length = image_processor.image_seq_length + + if not hasattr(tokenizer, "image_token"): + image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + self.image_token = IMAGE_TOKEN + else: + self.image_token_id = tokenizer.image_token_id + self.image_token = tokenizer.image_token + + tokenizer.add_tokens(EXTRA_TOKENS) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[PaliGemmaProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text` + and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + The usage for PaliGemma fine-tuning preparation is slightly different than usual. The `suffix` passed contains + suffixes to the prompts in `text`, and will be placed after each prompt. This is because attention is handled differently for + the prefix and the suffix. For instance, + ```python + image = PIL_cow_image + prompt = "answer en Where is the cow standing?" + suffix = "on the beach" + inputs = processor(text=prompt, images=image, suffix=suffix) + ``` + Here `inputs` will contain the `input_ids` and `token_type_ids` that follow + ```python + inputs["input_ids"][:, 256:] + # tensor([[ 2, 6006, 603, 573, 13910, 9980, 235336, 108, 477, 573, 8318]]) + inputs["token_type_ids"][:, 256:] + # tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]) + ``` + Meaning the last three tokens are of "label" ("suffix") type while the other ones are of "prefix" type. + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string).
If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + suffix (`str`, `List[str]`, `List[List[str]]`): + The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md + for more information. If your prompt is "<image> What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench". + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix` + is provided, the `input_ids` will also contain the suffix input ids. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **labels** -- Labels compatible with training if `suffix` is not `None`. + """ + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + PaliGemmaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) + + return_token_type_ids = True if suffix is not None else False + + if images is None: + raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.") + if text is None: + logger.warning_once( + "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model." + ) + text = "" + + if _is_str_or_image(text): + text = [text] + elif isinstance(text, list) and _is_str_or_image(text[0]): + pass + + if text is not None and images is not None: + if not any(IMAGE_TOKEN in sample for sample in text): + logger.warning( + "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special " + "image tokens in the text, as many tokens as there are images per each text. It is recommended to " + "add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images " + "each text has and add special tokens." + ) + + if isinstance(text, List) and isinstance(images, List): + if len(images) != len(text): + raise ValueError( + f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
+ ) + + # make a nested list of lists to be able to iterate over the images and text below + if is_valid_image(images): + images = [[images]] + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + images = [[image] for image in images] + elif not ( + isinstance(images, (list, tuple)) + and isinstance(images[0], (list, tuple)) + and is_valid_image(images[0][0]) + ): + raise ValueError("images must be an image, list of images or list of list of images") + + input_strings = [ + build_string_from_input( + prompt=prompt, + bos_token=self.tokenizer.bos_token, + image_seq_len=self.image_seq_length, + image_token=IMAGE_TOKEN, + num_images=len(image_list) if isinstance(image_list, list) else 1, + ) + for prompt, image_list in zip(text, images) + ] + images = make_flat_list_of_images(images) + else: + expanded_samples = [] + for sample in text: + expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) + bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN) + bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0 + expanded_sample = ( + expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:] + ) + expanded_samples.append(expanded_sample) + input_strings = [f"{sample}\n" for sample in expanded_samples] + + if suffix is not None and _is_str_or_image(suffix): + suffix = [suffix] + if suffix is not None: + suffix = [sfx + self.tokenizer.eos_token for sfx in suffix] + pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + inputs = self.tokenizer( + input_strings, + text_pair=suffix, + return_token_type_ids=return_token_type_ids, + **output_kwargs["text_kwargs"], + ) + self._check_special_mm_tokens(input_strings, inputs, modalities=["image"]) + + return_data = {**inputs, "pixel_values": pixel_values} + + if return_token_type_ids: + labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100) + return_data.update({"labels": labels}) + return BatchFeature(data=return_data, tensor_type=return_tensors) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["PaliGemmaProcessor"] diff --git a/docs/transformers/src/transformers/models/patchtsmixer/__init__.py b/docs/transformers/src/transformers/models/patchtsmixer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..285c1970308a47827806fca349d130703f40a2c8 --- /dev/null +++ b/docs/transformers/src/transformers/models/patchtsmixer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_patchtsmixer import * + from .modeling_patchtsmixer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/patchtsmixer/configuration_patchtsmixer.py b/docs/transformers/src/transformers/models/patchtsmixer/configuration_patchtsmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..83f374651a7cfc998349a734d8350fc399ea7f07 --- /dev/null +++ b/docs/transformers/src/transformers/models/patchtsmixer/configuration_patchtsmixer.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PatchTSMixer model configuration""" + +from typing import List, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PatchTSMixerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PatchTSMixerModel`]. It is used to instantiate a + PatchTSMixer model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the PatchTSMixer + [ibm/patchtsmixer-etth1-pretrain](https://huggingface.co/ibm/patchtsmixer-etth1-pretrain) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + context_length (`int`, *optional*, defaults to 32): + The context/history length for the input sequence. + patch_length (`int`, *optional*, defaults to 8): + The patch length for the input sequence. + num_input_channels (`int`, *optional*, defaults to 1): + Number of input variates. For Univariate, set it to 1. + patch_stride (`int`, *optional*, defaults to 8): + Determines the overlap between two consecutive patches. Set it to patch_length (or greater), if we want + non-overlapping patches. + num_parallel_samples (`int`, *optional*, defaults to 100): + The number of samples to generate in parallel for probabilistic forecast. + d_model (`int`, *optional*, defaults to 8): + Hidden dimension of the model. Recommended to set it as a multiple of patch_length (i.e. 2-5X of + patch_length). Larger value indicates more complex model. + expansion_factor (`int`, *optional*, defaults to 2): + Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model. + num_layers (`int`, *optional*, defaults to 3): + Number of layers to use. Recommended range is 3-15. Larger value indicates more complex model. + dropout (`float`, *optional*, defaults to 0.2): + The dropout probability the `PatchTSMixer` backbone. Recommended range is 0.2-0.7 + mode (`str`, *optional*, defaults to `"common_channel"`): + Mixer Mode. Determines how to process the channels. Allowed values: "common_channel", "mix_channel". In + "common_channel" mode, we follow Channel-independent modelling with no explicit channel-mixing. Channel + mixing happens in an implicit manner via shared weights across channels. (preferred first approach) In + "mix_channel" mode, we follow explicit channel-mixing in addition to patch and feature mixer. (preferred + approach when channel correlations are very important to model) + gated_attn (`bool`, *optional*, defaults to `True`): + Enable Gated Attention. + norm_mlp (`str`, *optional*, defaults to `"LayerNorm"`): + Normalization layer (BatchNorm or LayerNorm). + self_attn (`bool`, *optional*, defaults to `False`): + Enable Tiny self attention across patches. This can be enabled when the output of Vanilla PatchTSMixer with + gated attention is not satisfactory. Enabling this leads to explicit pair-wise attention and modelling + across patches. + self_attn_heads (`int`, *optional*, defaults to 1): + Number of self-attention heads. Works only when `self_attn` is set to `True`. + use_positional_encoding (`bool`, *optional*, defaults to `False`): + Enable the use of positional embedding for the tiny self-attention layers. Works only when `self_attn` is + set to `True`. + positional_encoding_type (`str`, *optional*, defaults to `"sincos"`): + Positional encodings. Options `"random"` and `"sincos"` are supported. Works only when + `use_positional_encoding` is set to `True` + scaling (`string` or `bool`, *optional*, defaults to `"std"`): + Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the + scaler is set to "mean". 
+ + loss (`string`, *optional*, defaults to `"mse"`): + The loss function for the model corresponding to the `distribution_output` head. For parametric + distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared + error "mse". + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal weight initialization distribution. + post_init (`bool`, *optional*, defaults to `False`): + Whether to use custom weight initialization from the `transformers` library, or the default initialization in + `PyTorch`. Setting it to `False` performs `PyTorch` weight initialization. + norm_eps (`float`, *optional*, defaults to 1e-05): + A value added to the denominator for numerical stability of normalization. + mask_type (`str`, *optional*, defaults to `"random"`): + Type of masking to use for masked pretraining mode. Allowed values are "random", "forecast". In random + masking, points are masked randomly. In forecast masking, points are masked towards the end. + random_mask_ratio (`float`, *optional*, defaults to 0.5): + Masking ratio to use when `mask_type` is `random`. A higher value indicates more masking. + num_forecast_mask_patches (`int` or `list`, *optional*, defaults to `[2]`): + Number of patches to be masked at the end of each batch sample. If it is an integer, all the samples in the + batch will have the same number of masked patches. If it is a list, samples in the batch will be randomly + masked by numbers defined in the list. This argument is only used for forecast pretraining. + mask_value (`float`, *optional*, defaults to `0.0`): + Mask value to use. + masked_loss (`bool`, *optional*, defaults to `True`): + Whether to compute the pretraining loss only at the masked portions, or on the entire output. + channel_consistent_masking (`bool`, *optional*, defaults to `True`): + When true, masking will be the same across all channels of a time series. Otherwise, masking positions will vary + across channels. + unmasked_channel_indices (`list`, *optional*): + Channels that are not masked during pretraining. + head_dropout (`float`, *optional*, defaults to 0.2): + The dropout probability for the `PatchTSMixer` head. + distribution_output (`string`, *optional*, defaults to `"student_t"`): + The distribution emission head for the model when the loss is "nll". Could be either "student_t", "normal" or + "negative_binomial". + prediction_length (`int`, *optional*, defaults to 16): + Number of time steps to forecast for a forecasting task. Also known as the forecast horizon. + prediction_channel_indices (`list`, *optional*): + List of channel indices to forecast. If `None`, forecast all channels. Target data is expected to have all + channels and we explicitly filter the channels in prediction and target before loss computation. + num_targets (`int`, *optional*, defaults to 3): + Number of targets (dimensionality of the regressed variable) for a regression task. + output_range (`list`, *optional*): + Output range to restrict for the regression task. Defaults to `None`. + head_aggregation (`str`, *optional*, defaults to `"max_pool"`): + Aggregation mode to enable for the classification or regression task. Allowed values are `None`, "use_last", + "max_pool", "avg_pool".
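+
+        The number of patches seen by the model is derived from these arguments as
+        `num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1`; with the
+        defaults above this gives `(32 - 8) // 8 + 1 = 4` patches per channel.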
+ + Example: + + ```python + >>> from transformers import PatchTSMixerConfig, PatchTSMixerModel + + >>> # Initializing a default PatchTSMixer configuration + >>> configuration = PatchTSMixerConfig() + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = PatchTSMixerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "patchtsmixer" + attribute_map = { + "hidden_size": "d_model", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + # Time series specific configuration + context_length: int = 32, + patch_length: int = 8, + num_input_channels: int = 1, + patch_stride: int = 8, + num_parallel_samples: int = 100, + # General model configuration + d_model: int = 8, + expansion_factor: int = 2, + num_layers: int = 3, + dropout: float = 0.2, + mode: str = "common_channel", + gated_attn: bool = True, + norm_mlp: str = "LayerNorm", + self_attn: bool = False, + self_attn_heads: int = 1, + use_positional_encoding: bool = False, + positional_encoding_type: str = "sincos", + scaling: Optional[Union[str, bool]] = "std", + loss: str = "mse", + init_std: float = 0.02, + post_init: bool = False, + norm_eps: float = 1e-5, + # Pretrain model configuration + mask_type: str = "random", + random_mask_ratio: float = 0.5, + num_forecast_mask_patches: Optional[Union[List[int], int]] = [2], + mask_value: int = 0, + masked_loss: bool = True, + channel_consistent_masking: bool = True, + unmasked_channel_indices: Optional[List[int]] = None, + # General head configuration + head_dropout: float = 0.2, + distribution_output: str = "student_t", + # Prediction head configuration + prediction_length: int = 16, + prediction_channel_indices: list = None, + # Classification/Regression configuration + num_targets: int = 3, + output_range: list = None, + head_aggregation: str = "max_pool", + **kwargs, + ): + self.num_input_channels = num_input_channels + self.context_length = context_length + self.patch_length = patch_length + self.patch_stride = patch_stride + self.d_model = d_model + self.expansion_factor = expansion_factor + self.num_layers = num_layers + self.dropout = dropout + self.mode = mode + self.gated_attn = gated_attn + self.norm_mlp = norm_mlp + self.scaling = scaling + self.head_dropout = head_dropout + self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1 + self.mask_type = mask_type + self.random_mask_ratio = random_mask_ratio + self.num_forecast_mask_patches = num_forecast_mask_patches + self.mask_value = mask_value + self.channel_consistent_masking = channel_consistent_masking + self.masked_loss = masked_loss + self.patch_last = True + self.use_positional_encoding = use_positional_encoding + self.positional_encoding_type = positional_encoding_type + self.prediction_length = prediction_length + self.prediction_channel_indices = prediction_channel_indices + self.num_targets = num_targets + self.output_range = output_range + self.head_aggregation = head_aggregation + self.self_attn = self_attn + self.self_attn_heads = self_attn_heads + self.init_std = init_std + self.post_init = post_init + self.distribution_output = distribution_output + self.loss = loss + self.num_parallel_samples = num_parallel_samples + self.unmasked_channel_indices = unmasked_channel_indices + self.norm_eps = norm_eps + super().__init__(**kwargs) + + +__all__ = ["PatchTSMixerConfig"] diff --git a/docs/transformers/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py 
b/docs/transformers/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca88a84a399ef11729898ed22c5fc884e9c23e70 --- /dev/null +++ b/docs/transformers/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -0,0 +1,2182 @@ +# coding=utf-8 +# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PatchTSMixer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_patchtsmixer import PatchTSMixerConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PatchTSMixerConfig" + + +PATCHTSMIXER_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PatchTSMixerConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. + mask_input (`bool`, *optional*, defaults to `False`): + If True, Masking will be enabled. False otherwise. +""" + +PATCHTSMIXER_INPUTS_DOCSTRING = r""" + Args: + past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): + Context values of the time series. For a pretraining task, this denotes the input time series to predict + the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly, + for classification or regression tasks, it denotes the appropriate context values of the time series. + + For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is + greater than 1. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class PatchTSMixerGatedAttention(nn.Module): + """ + Module that applies gated attention to input data. + + Args: + in_size (`int`): The input size. + out_size (`int`): The output size. 
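+
+    With `in_size == out_size`, the module computes `inputs * softmax(Linear(inputs), dim=-1)`, i.e. a
+    learned per-feature gate that preserves the input shape. A shape-only sketch (illustrative):
+
+    ```python
+    >>> import torch
+    >>> from transformers.models.patchtsmixer.modeling_patchtsmixer import PatchTSMixerGatedAttention
+
+    >>> gate = PatchTSMixerGatedAttention(in_size=8, out_size=8)
+    >>> gate(torch.randn(2, 3, 4, 8)).shape
+    torch.Size([2, 3, 4, 8])
+    ```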
+ """ + + def __init__(self, in_size: int, out_size: int): + super().__init__() + self.attn_layer = nn.Linear(in_size, out_size) + self.attn_softmax = nn.Softmax(dim=-1) + + def forward(self, inputs): + attn_weight = self.attn_softmax(self.attn_layer(inputs)) + inputs = inputs * attn_weight + return inputs + + +# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer +class PatchTSMixerBatchNorm(nn.Module): + """ + Compute batch normalization over the sequence length (time) dimension. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps) + + def forward(self, inputs: torch.Tensor): + """ + Parameters: + inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`): + input for Batch norm calculation + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length, d_model)` + """ + output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length) + output = self.batchnorm(output) + return output.transpose(1, 2) + + +class PatchTSMixerPositionalEncoding(nn.Module): + """ + Class for positional encoding + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + # positional encoding: [num_patches x d_model] + if config.use_positional_encoding: + self.position_enc = self._init_pe(config) + else: + self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model)) + + @staticmethod + def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter: + # Positional encoding + if config.positional_encoding_type == "random": + position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True) + elif config.positional_encoding_type == "sincos": + position_enc = torch.zeros(config.num_patches, config.d_model) + position = torch.arange(0, config.num_patches).unsqueeze(1) + div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model)) + position_enc[:, 0::2] = torch.sin(position * div_term) + position_enc[:, 1::2] = torch.cos(position * div_term) + position_enc = position_enc - position_enc.mean() + position_enc = position_enc / (position_enc.std() * 10) + position_enc = nn.Parameter(position_enc, requires_grad=False) + else: + raise ValueError( + f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'." + ) + return position_enc + + def forward(self, patch_input: torch.Tensor): + # hidden_state: [bs x num_channels x num_patches x d_model] + hidden_state = patch_input + self.position_enc + return hidden_state + + +class PatchTSMixerNormLayer(nn.Module): + """Normalization block + + Args: + config (`PatchTSMixerConfig`): + Configuration. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.norm_mlp = config.norm_mlp + + if "batch" in config.norm_mlp.lower(): + self.norm = PatchTSMixerBatchNorm(config) + else: + self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps) + + def forward(self, inputs: torch.Tensor): + """ + Args: + inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`): + Input to the normalization layer. 
+ Returns: + `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))` + """ + if "batch" in self.norm_mlp.lower(): + # reshape the data + inputs_reshaped = torch.reshape( + inputs, + ( + inputs.shape[0] * inputs.shape[1], + inputs.shape[2], + inputs.shape[3], + ), + ) # inputs_reshaped: [batch_size*num_channels, num_patches, d_model] + + # inputs_reshaped: [batch_size*num_channels, num_patches, d_model] + inputs_reshaped = self.norm(inputs_reshaped) + + # put back data to the original shape + inputs = torch.reshape(inputs_reshaped, inputs.shape) + + else: + inputs = self.norm(inputs) + + return inputs + + +class PatchTSMixerMLP(nn.Module): + def __init__(self, in_features, out_features, config): + super().__init__() + num_hidden = in_features * config.expansion_factor + self.fc1 = nn.Linear(in_features, num_hidden) + self.dropout1 = nn.Dropout(config.dropout) + self.fc2 = nn.Linear(num_hidden, out_features) + self.dropout2 = nn.Dropout(config.dropout) + + def forward(self, inputs: torch.Tensor): + """ + Args: + inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`): + Input to the MLP layer. + Returns: + `torch.Tensor` of the same shape as `inputs` + """ + inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs))) + inputs = self.fc2(inputs) + inputs = self.dropout2(inputs) + return inputs + + +class PatchTSMixerChannelFeatureMixerBlock(nn.Module): + """This module mixes the features in the channel dimension. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.norm = PatchTSMixerNormLayer(config) + self.gated_attn = config.gated_attn + self.mlp = PatchTSMixerMLP( + in_features=config.num_input_channels, + out_features=config.num_input_channels, + config=config, + ) + + if config.gated_attn: + self.gating_block = PatchTSMixerGatedAttention( + in_size=config.num_input_channels, out_size=config.num_input_channels + ) + + def forward(self, inputs: torch.Tensor): + """ + Args: + inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`): + input to the MLP layer + Returns: + `torch.Tensor` of the same shape as `inputs` + """ + residual = inputs + inputs = self.norm(inputs) + + inputs = inputs.permute(0, 3, 2, 1) + + if self.gated_attn: + inputs = self.gating_block(inputs) + + inputs = self.mlp(inputs) + + inputs = inputs.permute(0, 3, 2, 1) + + out = inputs + residual + return out + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer +class PatchTSMixerAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PatchTSMixerConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PatchMixerBlock(nn.Module): + """This module mixes the patch dimension. + + Args: + config (`PatchTSMixerConfig`): + Configuration. 
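+
+    The block transposes its input so that `num_patches` becomes the last dimension, applies an MLP
+    (and, optionally, gated attention and tiny self-attention) across patches, and transposes back,
+    mixing information along the time axis. An illustrative shape check, assuming the default
+    configuration (where `num_patches == 4` and `d_model == 8`):
+
+    ```python
+    >>> import torch
+    >>> from transformers import PatchTSMixerConfig
+    >>> from transformers.models.patchtsmixer.modeling_patchtsmixer import PatchMixerBlock
+
+    >>> block = PatchMixerBlock(PatchTSMixerConfig())
+    >>> block(torch.randn(2, 1, 4, 8)).shape
+    torch.Size([2, 1, 4, 8])
+    ```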
+ """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.norm = PatchTSMixerNormLayer(config) + + self.self_attn = config.self_attn + self.gated_attn = config.gated_attn + + self.mlp = PatchTSMixerMLP( + in_features=config.num_patches, + out_features=config.num_patches, + config=config, + ) + + if config.gated_attn: + self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches) + + if config.self_attn: + self.self_attn_layer = PatchTSMixerAttention( + embed_dim=config.d_model, + num_heads=config.self_attn_heads, + dropout=config.dropout, + ) + self.norm_attn = PatchTSMixerNormLayer(config) + + def forward(self, hidden_state): + """ + Args: + hidden_state (`torch.Tensor`): Input tensor. + + Returns: + `torch.Tensor`: Transformed tensor. + """ + residual = hidden_state + + hidden_state = self.norm(hidden_state) + + if self.self_attn: + batch_size, n_vars, num_patches, d_model = hidden_state.shape + hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model) + + x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False) + x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model) + + # Transpose so that num_patches is the last dimension + hidden_state = hidden_state.transpose(2, 3) + hidden_state = self.mlp(hidden_state) + + if self.gated_attn: + hidden_state = self.gating_block(hidden_state) + + # Transpose back + hidden_state = hidden_state.transpose(2, 3) + + if self.self_attn: + hidden_state = self.norm_attn(hidden_state + x_attn) + + out = hidden_state + residual + return out + + +class FeatureMixerBlock(nn.Module): + """This module mixes the hidden feature dimension. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.norm = PatchTSMixerNormLayer(config) + + self.gated_attn = config.gated_attn + + self.mlp = PatchTSMixerMLP( + in_features=config.d_model, + out_features=config.d_model, + config=config, + ) + + if config.gated_attn: + self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model) + + def forward(self, hidden: torch.Tensor): + """ + Args: + hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`): + Input tensor to the layer. + + Returns: + `torch.Tensor`: Transformed tensor. + """ + residual = hidden + hidden = self.norm(hidden) + hidden = self.mlp(hidden) + + if self.gated_attn: + hidden = self.gating_block(hidden) + + out = hidden + residual + return out + + +class PatchTSMixerLayer(nn.Module): + """ + The `PatchTSMixer` layer that does all three kinds of mixing. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.patch_mixer = PatchMixerBlock(config=config) + self.feature_mixer = FeatureMixerBlock(config=config) + + self.mode = config.mode + + if config.mode == "mix_channel": + self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config) + + def forward(self, hidden: torch.Tensor): + """ + Args: + hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`): + Input tensor to the layer. + + Returns: + `torch.Tensor`: Transformed tensor. 
+ """ + if self.mode == "mix_channel": + hidden = self.channel_feature_mixer(hidden) + + hidden = self.patch_mixer(hidden) + hidden = self.feature_mixer(hidden) # hidden: (batch_size x num_patches x d_model) + return hidden + + +class PatchTSMixerBlock(nn.Module): + """The main computing framework of the `PatchTSMixer` model. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + num_layers = config.num_layers + + self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)]) + + def forward(self, hidden_state, output_hidden_states: bool = False): + """ + Args: + hidden_state (`torch.Tensor`): The input tensor. + output_hidden_states (`bool`, *optional*, defaults to False.): + Whether to output the hidden states as well. + + Returns: + `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to + `True`. + """ + all_hidden_states = [] + + embedding = hidden_state + + for mod in self.mixers: + embedding = mod(embedding) + if output_hidden_states: + all_hidden_states.append(embedding) + + if output_hidden_states: + return embedding, all_hidden_states + else: + return embedding, None + + +class PatchTSMixerForPredictionHead(nn.Module): + """Prediction Head for Forecasting + + Args: + config (`PatchTSMixerConfig`): + Configuration. + """ + + def __init__(self, config: PatchTSMixerConfig, distribution_output=None): + super().__init__() + + self.prediction_channel_indices = config.prediction_channel_indices + + if self.prediction_channel_indices is not None: + self.prediction_channel_indices.sort() + + self.dropout_layer = nn.Dropout(config.head_dropout) + if distribution_output is None: + self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length) + else: + self.base_forecast_block = distribution_output.get_parameter_projection( + config.num_patches * config.d_model + ) + + self.flatten = nn.Flatten(start_dim=-2) + + def forward(self, hidden_features): + """ + + Args: + hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode + or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden + features. + + Returns: + `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`. + + """ + + hidden_features = self.flatten(hidden_features) # [batch_size x n_vars x num_patch * d_model] + hidden_features = self.dropout_layer(hidden_features) # [batch_size x n_vars x num_patch * d_model] + forecast = self.base_forecast_block(hidden_features) # [batch_size x n_vars x prediction_length] + if isinstance(forecast, tuple): + forecast = tuple(z.transpose(-1, -2) for z in forecast) + else: + forecast = forecast.transpose(-1, -2) # [batch_size x prediction_length x n_vars] + + if self.prediction_channel_indices is not None: + if isinstance(forecast, tuple): + forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast) + else: + forecast = forecast[..., self.prediction_channel_indices] # [batch_size x prediction_length x n_vars] + + return forecast + + +class PatchTSMixerLinearHead(nn.Module): + """Linear head for Classification and Regression. + + Args: + config (`PatchTSMixerConfig`): + Configuration. 
+ """ + + def __init__(self, config: PatchTSMixerConfig, distribution_output=None): + super().__init__() + + self.head_aggregation = config.head_aggregation + self.output_range = config.output_range + + if config.head_aggregation is None: + mul_factor = config.num_patches + else: + mul_factor = 1 + self.distribution_output = distribution_output + if distribution_output is None: + self.projection = nn.Linear( + config.d_model * config.num_input_channels * mul_factor, + config.num_targets, + ) + else: + self.projection = distribution_output.get_parameter_projection( + config.d_model * config.num_input_channels * mul_factor + ) + + if config.head_aggregation is None: + self.flatten = nn.Flatten(start_dim=-3) + else: + self.flatten = nn.Flatten(start_dim=-2) + + self.dropout = nn.Dropout(config.head_dropout) + + def forward(self, hidden_features): + """ + Args: + hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode + or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden + features. + + Returns: + `torch.Tensor` of shape `(batch_size x num_targets)`. + """ + + # batch_size x d_model x num_patch or batch_size x n_vars x d_model x num_patch + hidden_features = hidden_features.transpose(-1, -2) + if self.head_aggregation == "use_last": + # batch_size x d_model (flatten) or # batch_size x n_vars x d_model (common_channel) + hidden_features = hidden_features[..., -1] + elif self.head_aggregation == "max_pool": + # batch_size x n_vars x d_model or batch_size x d_model + hidden_features = hidden_features.max(dim=-1).values + elif self.head_aggregation == "avg_pool": + # batch_size x n_vars x d_model or batch_size x d_model + hidden_features = hidden_features.mean(dim=-1) + + if self.flatten: + hidden_features = self.flatten(hidden_features) + hidden_features = self.dropout(hidden_features) + hidden_features = self.projection(hidden_features) # batch_size x num_targets + + if (self.distribution_output is None) and (self.output_range is not None): + hidden_features = ( + torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0] + ) + return hidden_features + + +class PatchTSMixerPreTrainedModel(PreTrainedModel): + # Weight initialization + config_class = PatchTSMixerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + """Initialize weights""" + if isinstance(module, PatchTSMixerPositionalEncoding): + # initialize positional encoding + if self.config.positional_encoding_type == "random": + nn.init.normal_(module.position_enc, mean=0.0, std=0.1) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, PatchTSMixerBatchNorm): + module.batchnorm.bias.data.zero_() + module.batchnorm.weight.data.fill_(1.0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + + +class PatchTSMixerPretrainHead(nn.Module): + """Pretraining head. + + Args: + config (`PatchTSMixerConfig`): + Configuration. 
+ """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.dropout_layer = nn.Dropout(config.head_dropout) + self.base_pt_block = nn.Linear(config.d_model, config.patch_length) + + def forward(self, hidden_features): + """ + Args: + hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode + or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden + features. + + Returns: + `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`. + """ + + hidden_features = self.dropout_layer(hidden_features) + forecast = self.base_pt_block(hidden_features) # [batch_size x n_vars x num_patch x patch_length] + return forecast + + +# Copied from transformers.models.patchtst.modeling_patchtst.random_masking +def random_masking( + inputs: torch.Tensor, + mask_ratio: float, + unmasked_channel_indices: list = None, + channel_consistent_masking: bool = False, + mask_value: int = 0, +): + """random_masking: Mask the input considering the control variables. + + Args: + inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`): + The input tensor to mask. + mask_ratio (`float`): + Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1. + unmasked_channel_indices (list, *optional*): + Indices of channels that will not be masked. + channel_consistent_masking (bool, *optional*, defaults to `False`): + When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary + across channels. + mask_value (int, *optional*, defaults to 0): + Define the value of masked patches for pretraining. + + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x + n] + """ + if mask_ratio < 0 or mask_ratio >= 1: + raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.") + + batch_size, num_channels, sequence_length, num_features = inputs.shape + device = inputs.device + + len_keep = int(sequence_length * (1 - mask_ratio)) + + if channel_consistent_masking: + noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L + noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time + else: + # noise in [0, 1], bs x num_channels x L + noise = torch.rand(batch_size, num_channels, sequence_length, device=device) + + # mask: [bs x num_channels x num_patch] + mask = torch.ones(batch_size, num_channels, sequence_length, device=device) + mask[:, :, :len_keep] = 0 + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L] + + mask = torch.gather(mask, dim=-1, index=ids_restore) + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +# Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking +def forecast_masking( + inputs: torch.Tensor, + num_forecast_mask_patches: Union[list, int], + unmasked_channel_indices: list = None, + mask_value: int = 0, +): + """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches. 
+ If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list. + + Parameters: + inputs (`torch.Tensor`): + Input of shape `(bs, num_channels, num_patch, patch_length)` + num_forecast_mask_patches (`list`): + Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5]. + unmasked_channel_indices (`list`, *optional*): + Indices of channels that are not masked. + mask_value (`int`, *optional*, defaults to 0): + Values in the masked patches will be filled by `mask_value`. + + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs, + num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)` + """ + + if isinstance(num_forecast_mask_patches, int): + num_forecast_mask_patches = [num_forecast_mask_patches] + forecast_mask_ratios = [1 for _ in num_forecast_mask_patches] + + batch_size, num_channels, sequence_length, num_features = inputs.shape + mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device) + + t_list = [] + total_length = 0 + total_ratio = sum(forecast_mask_ratios) + + for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios): + if patch_length <= 0 or patch_length >= sequence_length: + raise ValueError( + f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches." + ) + temp_len = int(batch_size * ratio / total_ratio) + t_list.append([patch_length, ratio, temp_len]) + total_length += temp_len + + t_list = sorted(t_list, key=lambda x: x[2]) + + if total_length < batch_size: + t_list[0][2] = t_list[0][2] + (batch_size - total_length) + elif total_length > batch_size: + t_list[-1][2] = t_list[-1][2] + (total_length - batch_size) + + batch1 = 0 + for patch_len, _, temp_len in t_list: + batch2 = batch1 + temp_len + mask[batch1:batch2, :, -patch_len:] = 1 + batch1 = batch2 + + perm = torch.randperm(mask.shape[0]) + mask = mask[perm] + + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTPatchify with PatchTST->PatchTSMixer +class PatchTSMixerPatchify(nn.Module): + """ + A class to patchify the time series sequence into different patches + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + + self.sequence_length = config.context_length + self.patch_length = config.patch_length + self.patch_stride = config.patch_stride + + if self.sequence_length <= self.patch_length: + raise ValueError( + f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})" + ) + + # get the number of patches + self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 + new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1) + self.sequence_start = self.sequence_length - new_sequence_length + + def forward(self, past_values: torch.Tensor): + """ + Parameters: + past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*): + Input for patchification + + Returns: + `torch.Tensor` of shape 
`(batch_size, num_channels, num_patches, patch_length)` + """ + sequence_length = past_values.shape[-2] + if sequence_length != self.sequence_length: + raise ValueError( + f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})." + ) + # output: [bs x new_sequence_length x num_channels] + output = past_values[:, self.sequence_start :, :] + # output: [bs x num_patches x num_input_channels x patch_length] + output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride) + # output: [bs x num_input_channels x num_patches x patch_length] + output = output.transpose(-2, -3).contiguous() + return output + + +# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer +class PatchTSMixerMasking(nn.Module): + """ + Class to perform random or forecast masking. + + Parameters: + config (`PatchTSMixerConfig`): model config + Returns: + x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + self.random_mask_ratio = config.random_mask_ratio + self.channel_consistent_masking = config.channel_consistent_masking + self.mask_type = config.mask_type + self.num_forecast_mask_patches = config.num_forecast_mask_patches + self.unmasked_channel_indices = config.unmasked_channel_indices + self.mask_value = config.mask_value + if self.unmasked_channel_indices is not None: + self.unmasked_channel_indices = sorted(self.unmasked_channel_indices) + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input + + Return: + masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + + """ + if self.mask_type == "random": + masked_input, mask = random_masking( + inputs=patch_input, + mask_ratio=self.random_mask_ratio, + unmasked_channel_indices=self.unmasked_channel_indices, + channel_consistent_masking=self.channel_consistent_masking, + mask_value=self.mask_value, + ) + elif self.mask_type == "forecast": + masked_input, mask = forecast_masking( + inputs=patch_input, + num_forecast_mask_patches=self.num_forecast_mask_patches, + unmasked_channel_indices=self.unmasked_channel_indices, + mask_value=self.mask_value, + ) + else: + raise ValueError(f"Invalid mask type {self.mask_type}.") + + # mask: [bs x num_input_channels x num_patch] + mask = mask.bool() + return masked_input, mask + + +# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler with PatchTST->PatchTSMixer +class PatchTSMixerStdScaler(nn.Module): + """ + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. 
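+
+    Concretely, `loc` is the mean of `data` over `dim` restricted to observed values, and
+    `scale = sqrt(variance + minimum_scale)`; the module returns `(data - loc) / scale` together
+    with `loc` and `scale`.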
+ """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer +class PatchTSMixerMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) + num_observed = observed_indicator.sum(self.dim, keepdim=True) + + scale = ts_sum / torch.clamp(num_observed, min=1) + + # If `default_scale` is provided, we use it, otherwise we use the scale + # of the batch. 
+ if self.default_scale is None: + batch_sum = ts_sum.sum(dim=0) + batch_observations = torch.clamp(num_observed.sum(0), min=1) + default_scale = torch.squeeze(batch_sum / batch_observations) + else: + default_scale = self.default_scale * torch.ones_like(scale) + + # apply default scale where there are no observations + scale = torch.where(num_observed > 0, scale, default_scale) + + # ensure the scale is at least `self.minimum_scale` + scale = torch.clamp(scale, min=self.minimum_scale) + scaled_data = data / scale + + if not self.keepdim: + scale = scale.squeeze(dim=self.dim) + + return scaled_data, torch.zeros_like(scale), scale + + +# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer +class PatchTSMixerNOPScaler(nn.Module): + """ + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + + def forward( + self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + return data, loc, scale + + +@dataclass +class PatchTSMixerEncoderOutput(ModelOutput): + """ + Base class for `PatchTSMixerEncoderOutput`, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`): + Hidden-state at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Hidden-states of the model at the output of each layer. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel): + """ + Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__(config) + + self.use_return_dict = config.use_return_dict + + self.patcher = nn.Linear(config.patch_length, config.d_model) + if config.use_positional_encoding: + self.positional_encoder = PatchTSMixerPositionalEncoding(config=config) + else: + self.positional_encoder = None + self.mlp_mixer_encoder = PatchTSMixerBlock(config=config) + + # Initialize weights and apply final processing + if config.post_init: + self.post_init() + + @replace_return_docstrings(output_type=PatchTSMixerEncoderOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + past_values: torch.Tensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSMixerEncoderOutput]: + r""" + Args: + past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`): + Context values of the time series. 
For a pretraining task, this denotes the input time series to + predict the masked portion. For a forecasting task, this denotes the history/past time series values. + Similarly, for classification or regression tasks, it denotes the appropriate context values of the + time series. + + For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, + it is greater than 1. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)` + """ + + return_dict = return_dict if return_dict is not None else self.use_return_dict + + # flatten [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model] + patches = self.patcher(past_values) + + # add positional encoder + if self.positional_encoder is not None: + patches = self.positional_encoder(patches) + + last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states) + + if not return_dict: + return tuple( + v + for v in [ + last_hidden_state, + hidden_states, + ] + ) + + return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states) + + +@dataclass +class PatchTSMixerModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`): + Hidden-state at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Hidden-states of the model at the output of each layer. + patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Patched input data to the model. + mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`,*optional*): + Bool Tensor indicating True in masked patches and False otherwise. + loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*): + Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin + enabled. + scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*): + Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin + enabled. 
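+
+        Here "revin" refers to reversible instance normalization: the returned `loc` and `scale`
+        allow the model outputs to be de-normalized outside the model.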
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + patch_input: Optional[torch.FloatTensor] = None + mask: Optional[torch.FloatTensor] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + "The PatchTSMixer Model for time-series forecasting.", + PATCHTSMIXER_START_DOCSTRING, +) +class PatchTSMixerModel(PatchTSMixerPreTrainedModel): + def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False): + super().__init__(config) + + self.use_return_dict = config.use_return_dict + self.encoder = PatchTSMixerEncoder(config) + self.patching = PatchTSMixerPatchify(config) + + if mask_input is True: + self.masking = PatchTSMixerMasking(config) + else: + self.masking = None + + if config.scaling == "mean": + self.scaler = PatchTSMixerMeanScaler(config) + elif config.scaling == "std" or config.scaling is True: + self.scaler = PatchTSMixerStdScaler(config) + else: + self.scaler = PatchTSMixerNOPScaler(config) + + # Initialize weights and apply final processing + if config.post_init: + self.post_init() + + @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PatchTSMixerModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + past_values: torch.Tensor, + observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = None, + ) -> PatchTSMixerModelOutput: + r""" + observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.use_return_dict + + mask = None + if observed_mask is None: + observed_mask = torch.ones_like(past_values) + scaled_past_values, loc, scale = self.scaler(past_values, observed_mask) + + patched_x = self.patching(scaled_past_values) # [batch_size x num_input_channels x num_patch x patch_length + + enc_input = patched_x + if self.masking is not None: + enc_input, mask = self.masking(patched_x) + # enc_input: [batch_size x num_input_channels x num_patch x patch_length] + # mask: [batch_size x num_input_channels x num_patch] + + encoder_output = self.encoder( + enc_input, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if isinstance(encoder_output, tuple): + encoder_output = PatchTSMixerEncoderOutput(*encoder_output) + + if not return_dict: + return tuple( + v + for v in [ + encoder_output.last_hidden_state, + encoder_output.hidden_states, + patched_x, + mask, + loc, + scale, + ] + ) + + return PatchTSMixerModelOutput( + last_hidden_state=encoder_output.last_hidden_state, + hidden_states=encoder_output.hidden_states, + patch_input=patched_x, + mask=mask, + loc=loc, + scale=scale, + ) + + +@dataclass +class PatchTSMixerForPreTrainingOutput(ModelOutput): + """ + Output type of [`PatchTSMixerForPreTrainingOutput`]. + + Args: + prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`): + Prediction output from the pretrain head. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Hidden-states of the model at the output of each layer. 
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`): + Backbone embeddings before passing through the head. + loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`): + Total loss + """ + + loss: Optional[torch.FloatTensor] = None + prediction_outputs: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel): + r""" + `PatchTSMixer` for mask pretraining. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + + Returns: + `None`. + """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__(config) + self.model = PatchTSMixerModel(config, mask_input=True) + self.head = PatchTSMixerPretrainHead(config=config) + self.masked_loss = config.masked_loss + self.use_return_dict = config.use_return_dict + + # Initialize weights and apply final processing + if config.post_init: + self.post_init() + + @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PatchTSMixerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + past_values: torch.Tensor, + observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = False, + return_loss: bool = True, + return_dict: Optional[bool] = None, + ) -> PatchTSMixerForPreTrainingOutput: + r""" + observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + return_loss (`bool`, *optional*): + Whether to return the loss in the `forward` call. + + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.use_return_dict + + if self.masked_loss is True: + loss = torch.nn.MSELoss(reduction="none") + else: + loss = torch.nn.MSELoss(reduction="mean") + + # past_values: tensor [batch_size x context_length x num_input_channels] + model_output = self.model( + past_values, + observed_mask=observed_mask, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) # x.last_hidden_state: [batch_size x nvars x num_patch x d_model] + if isinstance(model_output, tuple): + model_output = PatchTSMixerModelOutput(*model_output) + + x_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x nvars x num_patch x patch_length] + + if return_loss is True: + loss_val = loss(x_hat, model_output.patch_input) + else: + loss_val = None + + # calculate masked_loss + if self.masked_loss is True and loss_val is not None: + loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10) + + if not return_dict: + return tuple( + v + for v in [ + loss_val, + x_hat, + model_output.last_hidden_state, + model_output.hidden_states, + ] + ) + + return PatchTSMixerForPreTrainingOutput( + loss=loss_val, + prediction_outputs=x_hat, # tensor [batch_size x nvars x num_patch x patch_length] + last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model] + hidden_states=model_output.hidden_states, + ) + + +@dataclass +class PatchTSMixerForPredictionOutput(ModelOutput): + """ + Output type of [`PatchTSMixerForPredictionOutput`]. 
+
+    Args:
+        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
+            Prediction output from the forecast head.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+            Backbone embeddings before passing through the head.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        loss (*optional*, returned when `future_values` is provided, `torch.FloatTensor` of shape `()`):
+            Total loss.
+        loc (`torch.FloatTensor`, *optional*, of shape `(batch_size, 1, num_input_channels)`):
+            Input mean.
+        scale (`torch.FloatTensor`, *optional*, of shape `(batch_size, 1, num_input_channels)`):
+            Input std dev.
+
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    prediction_outputs: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    loc: Optional[torch.FloatTensor] = None
+    scale: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class SamplePatchTSMixerPredictionOutput(ModelOutput):
+    """
+    Base class for time series model prediction outputs that contains the sampled values from the chosen
+    distribution.
+
+    Args:
+        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_input_channels)`):
+            Sampled values from the chosen distribution.
+    """
+
+    sequences: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class SamplePatchTSMixerRegressionOutput(ModelOutput):
+    """
+    Base class for time series model prediction outputs that contains the sampled values from the chosen
+    distribution.
+
+    Args:
+        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`):
+            Sampled values from the chosen distribution.
+    """
+
+    sequences: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
+def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the negative log likelihood loss from input distribution with respect to target.
+    """
+    return -input.log_prob(target)
+
+
+# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
+def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
+    """
+    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
+    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
+
+    Args:
+        input_tensor (`torch.FloatTensor`):
+            Input tensor, of which the average must be computed.
+        weights (`torch.FloatTensor`, *optional*):
+            Weights tensor, of the same shape as `input_tensor`.
+        dim (`int`, *optional*):
+            The dim along which to average `input_tensor`.
+
+    Returns:
+        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
+    """
+    if weights is not None:
+        weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
+        sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
+        return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
+    else:
+        return input_tensor.mean(dim=dim)
+
+
+class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
+    r"""
+    `PatchTSMixer` for forecasting application.
+
+    Args:
+        config (`PatchTSMixerConfig`):
+            Configuration.
+
+    Returns:
+        `None`.
+    """
+
+    def __init__(self, config: PatchTSMixerConfig):
+        super().__init__(config)
+        self.loss = config.loss
+        self.use_return_dict = config.use_return_dict
+        self.prediction_channel_indices = config.prediction_channel_indices
+        self.num_parallel_samples = config.num_parallel_samples
+
+        if config.loss == "mse":
+            self.distribution_output = None
+        else:
+            dim = config.prediction_length
+            distribution_output_map = {
+                "student_t": StudentTOutput,
+                "normal": NormalOutput,
+                "negative_binomial": NegativeBinomialOutput,
+            }
+            output_class = distribution_output_map.get(config.distribution_output, None)
+            if output_class is not None:
+                self.distribution_output = output_class(dim=dim)
+            else:
+                raise ValueError(f"Unknown distribution output {config.distribution_output}")
+
+        self.model = PatchTSMixerModel(config)
+        self.head = PatchTSMixerForPredictionHead(
+            config=config,
+            distribution_output=self.distribution_output,
+        )
+
+        # Initialize weights and apply final processing
+        if config.post_init:
+            self.post_init()
+
+    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=PatchTSMixerForPredictionOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        observed_mask: Optional[torch.Tensor] = None,
+        future_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,
+        return_loss: bool = True,
+        return_dict: Optional[bool] = None,
+    ) -> PatchTSMixerForPredictionOutput:
+        r"""
+        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
+            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+            in `[0, 1]`:
+                - 1 for values that are **observed**,
+                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
+            values of the time series that serve as labels for the model. `future_values` is what the
+            Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
+            required for a pretraining task.
+
+            For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we want
+            to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
+            pass the target data with all channels, as channel filtering for both prediction and target will be
+            manually applied before the loss computation.
+        return_loss (`bool`, *optional*):
+            Whether to return the loss in the `forward` call.
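+
+        Example (a minimal usage sketch; the configuration values and tensor sizes below are
+        illustrative assumptions, not requirements):
+
+        ```python
+        >>> import torch
+        >>> from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction
+
+        >>> config = PatchTSMixerConfig(
+        ...     context_length=32, patch_length=8, num_input_channels=3, prediction_length=16
+        ... )
+        >>> model = PatchTSMixerForPrediction(config)
+        >>> past_values = torch.randn(2, config.context_length, config.num_input_channels)
+        >>> future_values = torch.randn(2, config.prediction_length, config.num_input_channels)
+        >>> outputs = model(past_values=past_values, future_values=future_values)
+        >>> outputs.prediction_outputs.shape
+        torch.Size([2, 16, 3])
+        ```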
+ + Returns: + + """ + if self.loss == "mse": + loss = nn.MSELoss(reduction="mean") + elif self.loss == "nll": + loss = nll + else: + raise ValueError("Invalid loss function: Allowed values: mse and nll") + + return_dict = return_dict if return_dict is not None else self.use_return_dict + + # past_values: tensor [batch_size x context_length x num_input_channels] + model_output = self.model( + past_values, + observed_mask=observed_mask, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) # model_output: [batch_size x nvars x num_patch x d_model] + if isinstance(model_output, tuple): + model_output = PatchTSMixerModelOutput(*model_output) + + # tensor [batch_size x prediction_length x num_input_channels] + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + if self.prediction_channel_indices is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution( + y_hat, + loc=model_output.loc[..., self.prediction_channel_indices], + scale=model_output.scale[..., self.prediction_channel_indices], + ) + if future_values is not None and return_loss is True: + loss_val = loss( + distribution, + future_values[..., self.prediction_channel_indices], + ) + # take average of the loss + loss_val = weighted_average(loss_val) + else: + y_hat = ( + y_hat * model_output.scale[..., self.prediction_channel_indices] + + model_output.loc[..., self.prediction_channel_indices] + ) + if future_values is not None and return_loss is True: + loss_val = loss(y_hat, future_values[..., self.prediction_channel_indices]) + else: + if self.distribution_output: + distribution = self.distribution_output.distribution( + y_hat, loc=model_output.loc, scale=model_output.scale + ) + if future_values is not None and return_loss is True: + loss_val = loss(distribution, future_values) + loss_val = weighted_average(loss_val) + else: + y_hat = y_hat * model_output.scale + model_output.loc + if future_values is not None and return_loss is True: + loss_val = loss(y_hat, future_values) + + if self.prediction_channel_indices is not None: + loc = model_output.loc[..., self.prediction_channel_indices] + scale = model_output.scale[..., self.prediction_channel_indices] + else: + loc = model_output.loc + scale = model_output.scale + + if not return_dict: + return tuple( + v + for v in [ + loss_val, + y_hat, + model_output.last_hidden_state, + model_output.hidden_states, + loc, + scale, + ] + ) + + return PatchTSMixerForPredictionOutput( + loss=loss_val, + prediction_outputs=y_hat, # tensor [batch_size x prediction_length x num_input_channels] + last_hidden_state=model_output.last_hidden_state, # x: [batch_size x nvars x num_patch x d_model] + hidden_states=model_output.hidden_states, + loc=loc, + scale=scale, + ) + + def generate( + self, + past_values: torch.Tensor, + observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSMixerPredictionOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Args: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + + observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. 
NaNs that were replaced by zeros).
+
+        Return:
+            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
+            number of samples, prediction_length, num_input_channels)`.
+        """
+        # get number of samples
+        num_parallel_samples = self.num_parallel_samples
+
+        # get model output
+        outputs = self(
+            past_values=past_values,
+            future_values=None,
+            observed_mask=observed_mask,
+            output_hidden_states=False,
+        )
+
+        # get distribution
+        distribution = self.distribution_output.distribution(
+            outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
+        )
+
+        # get samples: list of [batch_size x prediction_length x num_channels]
+        samples = [distribution.sample() for _ in range(num_parallel_samples)]
+
+        # stack tensors
+        samples = torch.stack(samples, dim=1)  # [batch_size x num_samples x prediction_length x num_channels]
+        return SamplePatchTSMixerPredictionOutput(sequences=samples)
+
+
+@dataclass
+class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
+    """
+    Output type of [`PatchTSMixerForTimeSeriesClassification`].
+
+    Args:
+        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
+            Prediction output from the classification head.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+            Backbone embeddings before passing through the head.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        loss (*optional*, returned when `target_values` is provided, `torch.FloatTensor` of shape `()`):
+            Total loss.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    prediction_outputs: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
+    r"""
+    `PatchTSMixer` for classification application.
+
+    Args:
+        config (`PatchTSMixerConfig`):
+            Configuration.
+
+    Returns:
+        `None`.
+    """
+
+    def __init__(self, config: PatchTSMixerConfig):
+        super().__init__(config)
+
+        self.model = PatchTSMixerModel(config)
+        self.head = PatchTSMixerLinearHead(
+            config=config,
+        )
+        self.use_return_dict = config.use_return_dict
+        if config.scaling in ["std", "mean", True]:
+            self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
+        else:
+            self.inject_scale = None
+
+        # Initialize weights and apply final processing
+        if config.post_init:
+            self.post_init()
+
+    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=PatchTSMixerForTimeSeriesClassificationOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        target_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,
+        return_loss: bool = True,
+        return_dict: Optional[bool] = None,
+    ) -> PatchTSMixerForTimeSeriesClassificationOutput:
+        r"""
+        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
+            values of the time series that serve as labels for the model. `target_values` is what the
+            Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
+            required for a pretraining task.
+
+            For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we want
+            to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
+            pass the target data with all channels, as channel filtering for both prediction and target will be
+            manually applied before the loss computation.
+
+            For a classification task, it has a shape of `(batch_size,)`.
+
+            For a regression task, it has a shape of `(batch_size, num_targets)`.
+        return_loss (`bool`, *optional*):
+            Whether to return the loss in the `forward` call.
+
+        Returns:
+
+        """
+
+        loss = torch.nn.CrossEntropyLoss()
+
+        return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+        model_output = self.model(
+            past_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )  # x: [batch_size x nvars x num_patch x d_model]
+        if isinstance(model_output, tuple):
+            model_output = PatchTSMixerModelOutput(*model_output)
+
+        if self.inject_scale is not None:
+            model_output.last_hidden_state = self.inject_scale(
+                model_output.last_hidden_state,
+                loc=model_output.loc,
+                scale=model_output.scale,
+            )  # x: [batch_size x nvars x num_patch x d_model]
+
+        y_hat = self.head(model_output.last_hidden_state)  # tensor [batch_size x n_labels]
+
+        if target_values is not None and return_loss is True:
+            loss_val = loss(y_hat, target_values)
+        else:
+            loss_val = None
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    loss_val,
+                    y_hat,
+                    model_output.last_hidden_state,
+                    model_output.hidden_states,
+                ]
+            )
+
+        return PatchTSMixerForTimeSeriesClassificationOutput(
+            loss=loss_val,
+            prediction_outputs=y_hat,  # tensor [batch_size x n_labels]
+            last_hidden_state=model_output.last_hidden_state,  # x: [batch_size x nvars x num_patch x d_model]
+            hidden_states=model_output.hidden_states,
+        )
+
+
+@dataclass
+class PatchTSMixerForRegressionOutput(ModelOutput):
+    """
+    Output type of [`PatchTSMixerForRegression`].
+
+    Args:
+        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
+            Prediction output from the regression head.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
+            Backbone embeddings before passing through the head.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        loss (*optional*, returned when `target_values` is provided, `torch.FloatTensor` of shape `()`):
+            Total loss.
+ """ + + loss: Optional[torch.FloatTensor] = None + regression_outputs: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class InjectScalerStatistics4D(nn.Module): + def __init__(self, d_model: int, num_patches: int, expansion: int = 2): + super().__init__() + + self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model) + self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model) + self.map_scale_expansion = nn.Linear(2, 2 * expansion) + self.map_scale_compression = nn.Linear(2 * expansion, 2) + self.num_patches = num_patches + + def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor): + """ + Args: + inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`) + loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`) + scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`) + Returns: + `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)` + """ + + mean = loc.transpose(-1, -2) # [batch_size x n_channels x 1 ] + mean = mean.unsqueeze(-2) # [batch_size x n_channels x 1 x 1] + mean = mean.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1] + + stdev = scale.transpose(-1, -2) # [batch_size x n_channels x 1 ] + stdev = stdev.unsqueeze(-2) # [batch_size x n_channels x 1 x 1] + stdev = stdev.repeat(1, 1, self.num_patches, 1) # [batch_size x n_channels x num_patch x 1] + + concat_stats = torch.cat([mean, stdev], dim=-1) # [batch_size x n_channels x num_patch x 2] + + concat_stats = self.map_scale_expansion(concat_stats) # [batch_size x n_channels x num_patch x (2*expansion)] + concat_stats = self.map_scale_compression(concat_stats) # [batch_size x n_channels x num_patch x 2] + + inputs = torch.cat([inputs, concat_stats], dim=-1) # [batch_size x channels x num_patch x d_model+2] + inputs = self.inverse_trans_expansion(inputs) # [batch_size x channels x num_patch x (expansion*d_model)] + inputs = self.inverse_trans_compression(inputs) # [batch_size x channels x num_patch x d_model] + + return inputs + + +class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel): + r""" + `PatchTSMixer` for regression application. + + Args: + config (`PatchTSMixerConfig`): + Configuration. + + Returns: + `None`. 
+ """ + + def __init__(self, config: PatchTSMixerConfig): + super().__init__(config) + + self.model = PatchTSMixerModel(config) + + self.loss = config.loss + self.distribution_output = config.distribution_output + + self.use_return_dict = config.use_return_dict + self.num_parallel_samples = config.num_parallel_samples + + if config.loss == "mse": + self.distribution_output = None + else: + distribution_output_map = { + "student_t": StudentTOutput, + "normal": NormalOutput, + "negative_binomial": NegativeBinomialOutput, + } + output_class = distribution_output_map.get(config.distribution_output) + if output_class is not None: + self.distribution_output = output_class(dim=config.num_targets) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + if config.scaling in ["std", "mean", True]: + self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches) + else: + self.inject_scale = None + + self.head = PatchTSMixerLinearHead( + config=config, + distribution_output=self.distribution_output, + ) + + # Initialize weights and apply final processing + if config.post_init: + self.post_init() + + @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PatchTSMixerForRegressionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + past_values: torch.Tensor, + target_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = False, + return_loss: bool = True, + return_dict: Optional[bool] = None, + ) -> PatchTSMixerForRegressionOutput: + r""" + target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting, + `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target + values of the time series, that serve as labels for the model. The `target_values` is what the + Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT + required for a pretraining task. + + For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want + to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter, + pass the target data with all channels, as channel Filtering for both prediction and target will be + manually applied before the loss computation. + + For a classification task, it has a shape of `(batch_size,)`. + + For a regression task, it has a shape of `(batch_size, num_targets)`. + return_loss (`bool`, *optional*): + Whether to return the loss in the `forward` call. 
+
+        Returns:
+
+        """
+
+        if self.loss == "mse":
+            loss = nn.MSELoss(reduction="mean")
+        elif self.loss == "nll":
+            loss = nll
+        else:
+            raise ValueError("Invalid loss function: Allowed values: mse and nll")
+
+        return_dict = return_dict if return_dict is not None else self.use_return_dict
+        model_output = self.model(
+            past_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )  # model_output: [batch_size x nvars x num_patch x d_model]
+        if isinstance(model_output, tuple):
+            model_output = PatchTSMixerModelOutput(*model_output)
+
+        if self.inject_scale is not None:
+            model_output.last_hidden_state = self.inject_scale(
+                model_output.last_hidden_state,
+                loc=model_output.loc,
+                scale=model_output.scale,
+            )  # x: [batch_size x nvars x num_patch x d_model]
+
+        y_hat = self.head(model_output.last_hidden_state)  # [batch_size x num_targets]
+
+        if target_values is not None and return_loss is True:
+            if self.distribution_output:
+                # `self.distribution_output` holds a distribution-output instance at this point,
+                # so check its type rather than comparing it against the config string
+                if isinstance(self.distribution_output, NegativeBinomialOutput) and torch.any(target_values < 0):
+                    raise ValueError("target_values cannot be negative for negative_binomial distribution.")
+                distribution = self.distribution_output.distribution(y_hat)
+                # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
+                y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
+                loss_val = loss(distribution, target_values)
+                # take average of the loss
+                loss_val = weighted_average(loss_val)
+            else:
+                loss_val = loss(y_hat, target_values)
+        else:
+            loss_val = None
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    loss_val,
+                    y_hat,
+                    model_output.last_hidden_state,
+                    model_output.hidden_states,
+                ]
+            )
+
+        return PatchTSMixerForRegressionOutput(
+            loss=loss_val,
+            regression_outputs=y_hat,  # tensor [batch_size x num_targets]
+            last_hidden_state=model_output.last_hidden_state,  # [batch_size x nvars x num_patch x d_model]
+            hidden_states=model_output.hidden_states,
+        )
+
+    def generate(
+        self,
+        past_values: torch.Tensor,
+    ) -> SamplePatchTSMixerRegressionOutput:
+        """
+        Generate sequences of sample predictions from a model with a probability distribution head.
+
+        Args:
+            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                Past values of the time series that serve as context in order to predict the target values.
+
+        Return:
+            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
+            number of samples, num_targets)`.
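+
+        Example (a sketch; sampling requires a distribution head, so `loss="nll"` is set. The other
+        configuration values are illustrative assumptions):
+
+        ```python
+        >>> import torch
+        >>> from transformers import PatchTSMixerConfig, PatchTSMixerForRegression
+
+        >>> config = PatchTSMixerConfig(
+        ...     context_length=32, patch_length=8, num_input_channels=3, num_targets=2,
+        ...     loss="nll", num_parallel_samples=5,
+        ... )
+        >>> model = PatchTSMixerForRegression(config)
+        >>> past_values = torch.randn(4, config.context_length, config.num_input_channels)
+        >>> samples = model.generate(past_values=past_values)
+        >>> samples.sequences.shape
+        torch.Size([4, 5, 2])
+        ```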
+ """ + # get number of samples + num_parallel_samples = self.num_parallel_samples + + # get model output + outputs = self( + past_values=past_values, + target_values=None, + output_hidden_states=False, + ) + + # get distribution + distribution = self.distribution_output.distribution(outputs.regression_outputs) + + # get samples + samples = [ + distribution.sample() for _ in range(num_parallel_samples) + ] # samples: list of [batch_size x num_targets] + # stack tensors + # [batch_size x num_samples x num_targets] + samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets) + return SamplePatchTSMixerRegressionOutput(sequences=samples) + + +__all__ = [ + "PatchTSMixerPreTrainedModel", + "PatchTSMixerModel", + "PatchTSMixerForPretraining", + "PatchTSMixerForPrediction", + "PatchTSMixerForTimeSeriesClassification", + "PatchTSMixerForRegression", +] diff --git a/docs/transformers/src/transformers/models/patchtst/__init__.py b/docs/transformers/src/transformers/models/patchtst/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a045bfb1820424f77c9b798a3c027569eed3eb --- /dev/null +++ b/docs/transformers/src/transformers/models/patchtst/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_patchtst import * + from .modeling_patchtst import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/patchtst/configuration_patchtst.py b/docs/transformers/src/transformers/models/patchtst/configuration_patchtst.py new file mode 100644 index 0000000000000000000000000000000000000000..be19f2383acd529e493ef3542b4187b8e50c85d0 --- /dev/null +++ b/docs/transformers/src/transformers/models/patchtst/configuration_patchtst.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PatchTST model configuration""" + +from typing import List, Optional, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class PatchTSTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`PatchTSTModel`]. It is used to instantiate an + PatchTST model according to the specified arguments, defining the model architecture. + [ibm/patchtst](https://huggingface.co/ibm/patchtst) architecture. + + Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_input_channels (`int`, *optional*, defaults to 1): + The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of + multivariate targets. + context_length (`int`, *optional*, defaults to 32): + The context length of the input sequence. + distribution_output (`str`, *optional*, defaults to `"student_t"`): + The distribution emission head for the model when loss is "nll". Could be either "student_t", "normal" or + "negative_binomial". + loss (`str`, *optional*, defaults to `"mse"`): + The loss function for the model corresponding to the `distribution_output` head. For parametric + distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared + error "mse". + patch_length (`int`, *optional*, defaults to 1): + Define the patch length of the patchification process. + patch_stride (`int`, *optional*, defaults to 1): + Define the stride of the patchification process. + num_hidden_layers (`int`, *optional*, defaults to 3): + Number of hidden layers. + d_model (`int`, *optional*, defaults to 128): + Dimensionality of the transformer layers. + num_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + share_embedding (`bool`, *optional*, defaults to `True`): + Sharing the input embedding across all channels. + channel_attention (`bool`, *optional*, defaults to `False`): + Activate channel attention block in the Transformer to allow channels to attend each other. + ffn_dim (`int`, *optional*, defaults to 512): + Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + norm_type (`str` , *optional*, defaults to `"batchnorm"`): + Normalization at each Transformer layer. Can be `"batchnorm"` or `"layernorm"`. + norm_eps (`float`, *optional*, defaults to 1e-05): + A value added to the denominator for numerical stability of normalization. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention probabilities. + positional_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability in the positional embedding layer. + path_dropout (`float`, *optional*, defaults to 0.0): + The dropout path in the residual block. + ff_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability used between the two layers of the feed-forward networks. + bias (`bool`, *optional*, defaults to `True`): + Whether to add bias in the feed-forward networks. + activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (string) in the Transformer.`"gelu"` and `"relu"` are supported. 
+        pre_norm (`bool`, *optional*, defaults to `True`):
+            Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is
+            applied after the residual block.
+        positional_encoding_type (`str`, *optional*, defaults to `"sincos"`):
+            Positional encodings. Options `"random"` and `"sincos"` are supported.
+        use_cls_token (`bool`, *optional*, defaults to `False`):
+            Whether the cls token is used.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated normal weight initialization distribution.
+        share_projection (`bool`, *optional*, defaults to `True`):
+            Sharing the projection layer across different channels in the forecast head.
+        scaling (`Union`, *optional*, defaults to `"std"`):
+            Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the
+            scaler is set to "mean".
+        do_mask_input (`bool`, *optional*):
+            Apply masking during the pretraining.
+        mask_type (`str`, *optional*, defaults to `"random"`):
+            Masking type. Only `"random"` and `"forecast"` are currently supported.
+        random_mask_ratio (`float`, *optional*, defaults to 0.5):
+            Masking ratio applied to mask the input data during random pretraining.
+        num_forecast_mask_patches (`int` or `list`, *optional*, defaults to `[2]`):
+            Number of patches to be masked at the end of each batch sample. If it is an integer,
+            all the samples in the batch will have the same number of masked patches. If it is a list,
+            samples in the batch will be randomly masked by numbers defined in the list. This argument is only used
+            for forecast pretraining.
+        channel_consistent_masking (`bool`, *optional*, defaults to `False`):
+            If channel consistent masking is `True`, all the channels will have the same masking pattern.
+        unmasked_channel_indices (`list`, *optional*):
+            Indices of channels that are not masked during pretraining. Values in the list are numbers between 1 and
+            `num_input_channels`.
+        mask_value (`int`, *optional*, defaults to 0):
+            Values in the masked patches will be filled by `mask_value`.
+        pooling_type (`str`, *optional*, defaults to `"mean"`):
+            Pooling of the embedding. `"mean"`, `"max"` and `None` are supported.
+        head_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the head.
+        prediction_length (`int`, *optional*, defaults to 24):
+            The prediction horizon that the model will output.
+        num_targets (`int`, *optional*, defaults to 1):
+            Number of targets for regression and classification tasks. For classification, it is the number of
+            classes.
+        output_range (`list`, *optional*):
+            Output range for the regression task. The range of output values can be set to enforce the model to
+            produce values within a range.
+        num_parallel_samples (`int`, *optional*, defaults to 100):
+            The number of samples generated in parallel for probabilistic prediction.
+ + + ```python + >>> from transformers import PatchTSTConfig, PatchTSTModel + + >>> # Initializing an PatchTST configuration with 12 time steps for prediction + >>> configuration = PatchTSTConfig(prediction_length=12) + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = PatchTSTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "patchtst" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_attention_heads", + "num_hidden_layers": "num_hidden_layers", + } + + def __init__( + self, + # time series specific configuration + num_input_channels: int = 1, + context_length: int = 32, + distribution_output: str = "student_t", + loss: str = "mse", + # PatchTST arguments + patch_length: int = 1, + patch_stride: int = 1, + # Transformer architecture configuration + num_hidden_layers: int = 3, + d_model: int = 128, + num_attention_heads: int = 4, + share_embedding: bool = True, + channel_attention: bool = False, + ffn_dim: int = 512, + norm_type: str = "batchnorm", + norm_eps: float = 1e-05, + attention_dropout: float = 0.0, + positional_dropout: float = 0.0, + path_dropout: float = 0.0, + ff_dropout: float = 0.0, + bias: bool = True, + activation_function: str = "gelu", + pre_norm: bool = True, + positional_encoding_type: str = "sincos", + use_cls_token: bool = False, + init_std: float = 0.02, + share_projection: bool = True, + scaling: Optional[Union[str, bool]] = "std", + # mask pretraining + do_mask_input: Optional[bool] = None, + mask_type: str = "random", + random_mask_ratio: float = 0.5, + num_forecast_mask_patches: Optional[Union[List[int], int]] = [2], + channel_consistent_masking: Optional[bool] = False, + unmasked_channel_indices: Optional[List[int]] = None, + mask_value: int = 0, + # head + pooling_type: str = "mean", + head_dropout: float = 0.0, + prediction_length: int = 24, + num_targets: int = 1, + output_range: Optional[List] = None, + # distribution head + num_parallel_samples: int = 100, + **kwargs, + ): + # time series specific configuration + self.context_length = context_length + self.num_input_channels = num_input_channels # n_vars + self.loss = loss + self.distribution_output = distribution_output + self.num_parallel_samples = num_parallel_samples + + # Transformer architecture configuration + self.d_model = d_model + self.num_attention_heads = num_attention_heads + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.attention_dropout = attention_dropout + self.share_embedding = share_embedding + self.channel_attention = channel_attention + self.norm_type = norm_type + self.norm_eps = norm_eps + self.positional_dropout = positional_dropout + self.path_dropout = path_dropout + self.ff_dropout = ff_dropout + self.bias = bias + self.activation_function = activation_function + self.pre_norm = pre_norm + self.positional_encoding_type = positional_encoding_type + self.use_cls_token = use_cls_token + self.init_std = init_std + self.scaling = scaling + + # PatchTST parameters + self.patch_length = patch_length + self.patch_stride = patch_stride + + # Mask pretraining + self.do_mask_input = do_mask_input + self.mask_type = mask_type + self.random_mask_ratio = random_mask_ratio # for random masking + self.num_forecast_mask_patches = num_forecast_mask_patches # for forecast masking + self.channel_consistent_masking = channel_consistent_masking + self.unmasked_channel_indices = unmasked_channel_indices + self.mask_value = 
mask_value + + # general head params + self.pooling_type = pooling_type + self.head_dropout = head_dropout + + # For prediction head + self.share_projection = share_projection + self.prediction_length = prediction_length + + # For prediction and regression head + self.num_parallel_samples = num_parallel_samples + + # Regression + self.num_targets = num_targets + self.output_range = output_range + + super().__init__(**kwargs) + + +__all__ = ["PatchTSTConfig"] diff --git a/docs/transformers/src/transformers/models/patchtst/modeling_patchtst.py b/docs/transformers/src/transformers/models/patchtst/modeling_patchtst.py new file mode 100644 index 0000000000000000000000000000000000000000..ae09d410ace06cd8ea0ebb95e0cf87ecab0b2a77 --- /dev/null +++ b/docs/transformers/src/transformers/models/patchtst/modeling_patchtst.py @@ -0,0 +1,2042 @@ +# coding=utf-8 +# Copyright 2023 IBM & Hugging Face. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PatchTST model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2CLS +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ModelOutput, add_start_docstrings, logging +from .configuration_patchtst import PatchTSTConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PatchTSTConfig" + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST +class PatchTSTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PatchTSTConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PatchTSTBatchNorm(nn.Module): + """ + Compute batch normalization over the sequence length (time) dimension. 
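+
+    A shape-preserving sketch (the `d_model` value is an illustrative assumption):
+
+    ```python
+    >>> import torch
+    >>> from transformers import PatchTSTConfig
+    >>> from transformers.models.patchtst.modeling_patchtst import PatchTSTBatchNorm
+
+    >>> norm = PatchTSTBatchNorm(PatchTSTConfig(d_model=16))
+    >>> norm(torch.randn(8, 10, 16)).shape  # (batch_size, sequence_length, d_model) in and out
+    torch.Size([8, 10, 16])
+    ```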
+ """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps) + + def forward(self, inputs: torch.Tensor): + """ + Parameters: + inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`): + input for Batch norm calculation + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length, d_model)` + """ + output = inputs.transpose(1, 2) # output: (batch_size, d_model, sequence_length) + output = self.batchnorm(output) + return output.transpose(1, 2) + + +def random_masking( + inputs: torch.Tensor, + mask_ratio: float, + unmasked_channel_indices: list = None, + channel_consistent_masking: bool = False, + mask_value: int = 0, +): + """random_masking: Mask the input considering the control variables. + + Args: + inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`): + The input tensor to mask. + mask_ratio (`float`): + Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1. + unmasked_channel_indices (list, *optional*): + Indices of channels that will not be masked. + channel_consistent_masking (bool, *optional*, defaults to `False`): + When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary + across channels. + mask_value (int, *optional*, defaults to 0): + Define the value of masked patches for pretraining. + + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x + n] + """ + if mask_ratio < 0 or mask_ratio >= 1: + raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.") + + batch_size, num_channels, sequence_length, num_features = inputs.shape + device = inputs.device + + len_keep = int(sequence_length * (1 - mask_ratio)) + + if channel_consistent_masking: + noise = torch.rand(batch_size, 1, sequence_length, device=device) # noise in [0, 1], bs x 1 x L + noise = noise.repeat(1, num_channels, 1) # bs x num_channels x time + else: + # noise in [0, 1], bs x num_channels x L + noise = torch.rand(batch_size, num_channels, sequence_length, device=device) + + # mask: [bs x num_channels x num_patch] + mask = torch.ones(batch_size, num_channels, sequence_length, device=device) + mask[:, :, :len_keep] = 0 + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=-1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=-1) # ids_restore: [bs x num_channels x L] + + mask = torch.gather(mask, dim=-1, index=ids_restore) + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patches x patch_length] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +def forecast_masking( + inputs: torch.Tensor, + num_forecast_mask_patches: Union[list, int], + unmasked_channel_indices: list = None, + mask_value: int = 0, +): + """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches. + If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list. + + Parameters: + inputs (`torch.Tensor`): + Input of shape `(bs, num_channels, num_patch, patch_length)` + num_forecast_mask_patches (`list`): + Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5]. 
+ unmasked_channel_indices (`list`, *optional*): + Indices of channels that are not masked. + mask_value (`int`, *optional*, defaults to 0): + Values in the masked patches will be filled by `mask_value`. + + Returns: + `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs, + num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)` + """ + + if isinstance(num_forecast_mask_patches, int): + num_forecast_mask_patches = [num_forecast_mask_patches] + forecast_mask_ratios = [1 for _ in num_forecast_mask_patches] + + batch_size, num_channels, sequence_length, num_features = inputs.shape + mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device) + + t_list = [] + total_length = 0 + total_ratio = sum(forecast_mask_ratios) + + for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios): + if patch_length <= 0 or patch_length >= sequence_length: + raise ValueError( + f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches." + ) + temp_len = int(batch_size * ratio / total_ratio) + t_list.append([patch_length, ratio, temp_len]) + total_length += temp_len + + t_list = sorted(t_list, key=lambda x: x[2]) + + if total_length < batch_size: + t_list[0][2] = t_list[0][2] + (batch_size - total_length) + elif total_length > batch_size: + t_list[-1][2] = t_list[-1][2] + (total_length - batch_size) + + batch1 = 0 + for patch_len, _, temp_len in t_list: + batch2 = batch1 + temp_len + mask[batch1:batch2, :, -patch_len:] = 1 + batch1 = batch2 + + perm = torch.randperm(mask.shape[0]) + mask = mask[perm] + + mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features) # mask: [bs x num_channels x num_patch x patch_len] + if unmasked_channel_indices is not None: + mask[:, unmasked_channel_indices, :, :] = 0 + + inputs_mask = inputs.masked_fill(mask.bool(), mask_value) + return inputs_mask, mask[..., 0] + + +class PatchTSTPatchify(nn.Module): + """ + A class to patchify the time series sequence into different patches + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + + self.sequence_length = config.context_length + self.patch_length = config.patch_length + self.patch_stride = config.patch_stride + + if self.sequence_length <= self.patch_length: + raise ValueError( + f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})" + ) + + # get the number of patches + self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1 + new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1) + self.sequence_start = self.sequence_length - new_sequence_length + + def forward(self, past_values: torch.Tensor): + """ + Parameters: + past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*): + Input for patchification + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)` + """ + sequence_length = past_values.shape[-2] + if sequence_length != self.sequence_length: + raise ValueError( + f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})." 
+ ) + # output: [bs x new_sequence_length x num_channels] + output = past_values[:, self.sequence_start :, :] + # output: [bs x num_patches x num_input_channels x patch_length] + output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride) + # output: [bs x num_input_channels x num_patches x patch_length] + output = output.transpose(-2, -3).contiguous() + return output + + +class PatchTSTMasking(nn.Module): + """ + Class to perform random or forecast masking. + + Parameters: + config (`PatchTSTConfig`): model config + Returns: + x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.random_mask_ratio = config.random_mask_ratio + self.channel_consistent_masking = config.channel_consistent_masking + self.mask_type = config.mask_type + self.num_forecast_mask_patches = config.num_forecast_mask_patches + self.unmasked_channel_indices = config.unmasked_channel_indices + self.mask_value = config.mask_value + if self.unmasked_channel_indices is not None: + self.unmasked_channel_indices = sorted(self.unmasked_channel_indices) + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input + + Return: + masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`) + Masked patched input + mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`) + Bool tensor indicating True on masked points + + """ + if self.mask_type == "random": + masked_input, mask = random_masking( + inputs=patch_input, + mask_ratio=self.random_mask_ratio, + unmasked_channel_indices=self.unmasked_channel_indices, + channel_consistent_masking=self.channel_consistent_masking, + mask_value=self.mask_value, + ) + elif self.mask_type == "forecast": + masked_input, mask = forecast_masking( + inputs=patch_input, + num_forecast_mask_patches=self.num_forecast_mask_patches, + unmasked_channel_indices=self.unmasked_channel_indices, + mask_value=self.mask_value, + ) + else: + raise ValueError(f"Invalid mask type {self.mask_type}.") + + # mask: [bs x num_input_channels x num_patch] + mask = mask.bool() + return masked_input, mask + + +class PatchTSTEncoderLayer(nn.Module): + """ + PatchTST encoder layer + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + + self.channel_attention = config.channel_attention + # Multi-Head attention + self.self_attn = PatchTSTAttention( + embed_dim=config.d_model, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + ) + + # Add & Norm of the sublayer 1 + self.dropout_path1 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity() + if config.norm_type == "batchnorm": + self.norm_sublayer1 = PatchTSTBatchNorm(config) + elif config.norm_type == "layernorm": + self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + else: + raise ValueError(f"{config.norm_type} is not a supported norm layer type.") + + # Add & Norm of the sublayer 2 + if self.channel_attention: + self.dropout_path2 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity() + if config.norm_type == "batchnorm": + self.norm_sublayer2 = PatchTSTBatchNorm(config) + elif config.norm_type == 
"layernorm": + self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + else: + raise ValueError(f"{config.norm_type} is not a supported norm layer type.") + + # Position-wise Feed-Forward + self.ff = nn.Sequential( + nn.Linear(config.d_model, config.ffn_dim, bias=config.bias), + ACT2CLS[config.activation_function](), + nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(), + nn.Linear(config.ffn_dim, config.d_model, bias=config.bias), + ) + + # Add & Norm of sublayer 3 + self.dropout_path3 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity() + if config.norm_type == "batchnorm": + self.norm_sublayer3 = PatchTSTBatchNorm(config) + elif config.norm_type == "layernorm": + self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps) + else: + raise ValueError(f"{config.norm_type} is not a supported norm layer type.") + + self.pre_norm = config.pre_norm + + def forward(self, hidden_state: torch.Tensor, output_attentions: Optional[bool] = None): + """ + Parameters: + hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*): + Past values of the time series + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + Return: + `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)` + + """ + batch_size, num_input_channels, sequence_length, d_model = hidden_state.shape + + # First sublayer: attention across time + # hidden_states: [(bs*num_channels) x sequence_length x d_model] + hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model) + + if self.pre_norm: + ## Norm and Multi-Head attention and Add residual connection + attn_output, attn_weights, _ = self.self_attn( + hidden_states=self.norm_sublayer1(hidden_state), output_attentions=output_attentions + ) + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path1(attn_output) + else: + ## Multi-Head attention and Add residual connection and Norm - Standard Transformer from BERT + attn_output, attn_weights, _ = self.self_attn( + hidden_states=hidden_state, output_attentions=output_attentions + ) + # hidden_states: [(bs*num_channels) x sequence_length x d_model] + hidden_state = self.norm_sublayer1(hidden_state + self.dropout_path1(attn_output)) + + # hidden_state: [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model) + + # second sublayer: attention across variable at any given time + if self.channel_attention: + # hidden_state: [bs x sequence_length x num_channels x d_model] + hidden_state = hidden_state.transpose(2, 1).contiguous() + # hidden_state: [(bs*sequence_length) x num_channels x d_model] + hidden_state = hidden_state.view(batch_size * sequence_length, num_input_channels, d_model) + if self.pre_norm: + ## Norm and Multi-Head attention and Add residual connection + attn_output, channel_attn_weights, _ = self.self_attn( + hidden_states=self.norm_sublayer2(hidden_state), output_attentions=output_attentions + ) + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path2(attn_output) + else: + ## Multi-Head attention and Add residual connection and Norm + attn_output, channel_attn_weights, _ = self.self_attn( + hidden_states=hidden_state, output_attentions=output_attentions + ) + # hidden_states: [(bs*sequence_length) x num_channels x d_model] + 
hidden_state = self.norm_sublayer2(hidden_state + self.dropout_path2(attn_output)) + + # Reshape hidden state + # hidden_state: [bs x sequence_length x num_channels x d_model] + hidden_state = hidden_state.reshape(batch_size, sequence_length, num_input_channels, d_model) + # hidden_state: [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.transpose(1, 2).contiguous() + + # Third sublayer: mixing across hidden + # hidden_state: [(batch_size*num_channels) x sequence_length x d_model] + hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model) + if self.pre_norm: + ## Norm and Position-wise Feed-Forward and Add residual connection + # Add: residual connection with residual dropout + hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state))) + else: + ## Position-wise Feed-Forward and Add residual connection and Norm + # Add: residual connection with residual dropout + hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state))) + + # [bs x num_channels x sequence_length x d_model] + hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model) + + outputs = (hidden_state,) + if output_attentions: + outputs += (attn_weights, channel_attn_weights) if self.channel_attention else (attn_weights,) + + return outputs + + +class PatchTSTPreTrainedModel(PreTrainedModel): + config_class = PatchTSTConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + """ + Initialize weights + """ + if isinstance(module, PatchTSTPositionalEncoding): + # initialize cls_token + if self.config.use_cls_token: + nn.init.normal_(module.cls_token, std=0.02) + # initialize positional encoding + if self.config.positional_encoding_type == "random": + nn.init.normal_(module.position_enc, mean=0.0, std=0.1) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, PatchTSTBatchNorm): + module.batchnorm.bias.data.zero_() + module.batchnorm.weight.data.fill_(1.0) + elif isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PatchTSTEncoder)): + module.gradient_checkpointing = value + + +class PatchTSTEmbedding(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.num_input_channels = config.num_input_channels + self.share_embedding = config.share_embedding + # Input encoding: projection of feature vectors onto a d-dim vector space + if self.share_embedding: + self.input_embedding = nn.Linear(config.patch_length, config.d_model) + else: + self.input_embedding = nn.ModuleList() + for _ in range(config.num_input_channels): + self.input_embedding.append(nn.Linear(config.patch_length, config.d_model)) + + def forward(self, patch_input: torch.Tensor): + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Patch input for embedding + return: + `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)` + """ + # Input encoding + num_input_channels = patch_input.shape[1] + if num_input_channels != self.num_input_channels: + raise ValueError( + f"The defined number of input channels 
({self.num_input_channels}) in the config " + f"has to be the same as the number of channels in the batch input ({num_input_channels})" + ) + if self.share_embedding: + embeddings = self.input_embedding(patch_input) # embeddings: [bs x num_channels x num_patches x d_model] + else: + embeddings = [self.input_embedding[i](patch_input[:, i, :, :]) for i in range(num_input_channels)] + embeddings = torch.stack(embeddings, dim=1) + return embeddings + + +class PatchTSTPositionalEncoding(nn.Module): + """ + Class for positional encoding + """ + + def __init__(self, config: PatchTSTConfig, num_patches: int): + super().__init__() + self.use_cls_token = config.use_cls_token + self.num_input_channels = config.num_input_channels + if config.use_cls_token: + # cls_token: [1 x 1 x 1 x d_model], expanded over batch and channels in `forward` + self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model)) + num_patches += 1 + # positional encoding: [num_patches x d_model] + self.position_enc = self._init_pe(config, num_patches) + # Positional dropout + self.positional_dropout = ( + nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity() + ) + + @staticmethod + def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter: + # Positional encoding + if config.positional_encoding_type == "random": + position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True) + elif config.positional_encoding_type == "sincos": + position_enc = torch.zeros(num_patches, config.d_model) + position = torch.arange(0, num_patches).unsqueeze(1) + div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model)) + position_enc[:, 0::2] = torch.sin(position * div_term) + position_enc[:, 1::2] = torch.cos(position * div_term) + position_enc = position_enc - position_enc.mean() + position_enc = position_enc / (position_enc.std() * 10) + position_enc = nn.Parameter(position_enc, requires_grad=False) + else: + raise ValueError( + f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'." 
+ ) + return position_enc + + def forward(self, patch_input: torch.Tensor): + if self.use_cls_token: + # patch_input: [bs x num_channels x num_patches x d_model] + patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :]) + # append cls token where cls_token: [1 x 1 x 1 x d_model] + cls_token = self.cls_token + self.position_enc[:1, :] + # get the same copy of cls_token for all the samples in batch: [bs x num_channels x 1 x d_model] + cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1) + # hidden_state: [bs x num_channels x (num_patches+1) x d_model] + hidden_state = torch.cat((cls_tokens, patch_input), dim=2) + else: + # hidden_state: [bs x num_channels x num_patches x d_model] + hidden_state = self.positional_dropout(patch_input + self.position_enc) + return hidden_state + + +class PatchTSTEncoder(PatchTSTPreTrainedModel): + """ + PatchTST Encoder + """ + + def __init__(self, config: PatchTSTConfig, num_patches: int): + super().__init__(config) + self.gradient_checkpointing = False + + # Input embedding: projection of feature vectors onto a d-dim vector space + self.embedder = PatchTSTEmbedding(config) + # Positional encoding + self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches) + # Encoder + self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + patch_input: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + ) -> BaseModelOutput: + """ + Parameters: + patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*): + Past values of the time series + output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. + output_attentions (`bool`, *optional*): Whether or not to return the attentions of all layers. + + Returns: + `BaseModelOutput` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # Input embedding + patch_input = self.embedder(patch_input) + # Positional encoding + hidden_state = self.positional_encoder(patch_input) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_state,) + + layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions) + # get hidden state. hidden_state shape is [bs x num_channels x num_patches x d_model] + # or [bs x num_channels x (num_patches+1) x d_model] if use cls_token + hidden_state = layer_outputs[0] + # append attention matrix at each layer + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions) + 
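+# A shape walk-through of the encoder, for orientation (illustrative; the "+1" applies only when +# config.use_cls_token is True): +# patch_input: [bs x num_channels x num_patches x patch_length] +# -> PatchTSTEmbedding: [bs x num_channels x num_patches x d_model] +# -> PatchTSTPositionalEncoding: [bs x num_channels x (num_patches+1) x d_model] +# -> each PatchTSTEncoderLayer preserves this shape + 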
+PATCHTST_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PatchTSTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@dataclass +class PatchTSTModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Parameters: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, num_patches, d_model)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*): + Bool masked tensor indicating which patches are masked + loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length + scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length + patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Patched input to the Transformer + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + mask: Optional[torch.FloatTensor] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + patch_input: Optional[torch.FloatTensor] = None + + +@dataclass +class PatchTSTForPretrainingOutput(ModelOutput): + """ + Output type of [`PatchTSTForPretraining`]. + + Parameters: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + MSE loss, averaged over the masked patches. + prediction_output (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`): + Prediction outputs of the time series modeling heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + prediction_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PatchTSTForRegressionOutput(ModelOutput): + """ + Output type of [`PatchTSTForRegression`]. + + Parameters: + loss (*optional*, returned when `target_values` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`): + Regression outputs of the time series modeling heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + regression_outputs: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PatchTSTForPredictionOutput(ModelOutput): + """ + Output type of [`PatchTSTForPrediction`]. + + Parameters: + loss (*optional*, returned when `future_values` is provided, `torch.FloatTensor` of shape `(1,)`): + MSE loss. + prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_channels)`): + Prediction outputs of the time series modeling heads. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length + scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*): + Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length + """ + + loss: Optional[torch.FloatTensor] = None + prediction_outputs: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + + +@dataclass +class PatchTSTForClassificationOutput(ModelOutput): + """ + Output type of [`PatchTSTForClassification`]. + + Parameters: + loss (*optional*, returned when `target_values` is provided, `torch.FloatTensor` of shape `(1,)`): + Classification (cross entropy) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`): + Prediction scores of the PatchTST modeling head (scores before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SamplePatchTSTOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. + + Parameters: + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`): + Sampled values from the chosen distribution. + """ + + sequences: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. 
+ dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. + """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTStdScaler(nn.Module): + """ + Standardizes features by computing the mean and standard deviation along the first dimension, and then normalizes + the data by subtracting the mean and dividing by the standard deviation. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Input for scale calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Input for scale calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True) + num_observed = observed_indicator.sum(self.dim, keepdim=True) + + scale = ts_sum / torch.clamp(num_observed, min=1) + + # If `default_scale` is provided, we use it, otherwise we use the scale + # of the batch. + if self.default_scale is None: + batch_sum = ts_sum.sum(dim=0) + batch_observations = torch.clamp(num_observed.sum(0), min=1) + default_scale = torch.squeeze(batch_sum / batch_observations) + else: + default_scale = self.default_scale * torch.ones_like(scale) + + # apply default scale where there are no observations + scale = torch.where(num_observed > 0, scale, default_scale) + + # ensure the scale is at least `self.minimum_scale` + scale = torch.clamp(scale, min=self.minimum_scale) + scaled_data = data / scale + + if not self.keepdim: + scale = scale.squeeze(dim=self.dim) + + return scaled_data, torch.zeros_like(scale), scale + + +# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST +class PatchTSTNOPScaler(nn.Module): + """ + Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + + def forward( + self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Input for scale calculation + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim) + return data, loc, scale + + +class PatchTSTScaler(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + if config.scaling == "mean" or config.scaling is True: + self.scaler = PatchTSTMeanScaler(config) + elif config.scaling == "std": + self.scaler = PatchTSTStdScaler(config) + else: + self.scaler = PatchTSTNOPScaler(config) + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Input for scaler calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+ Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + data, loc, scale = self.scaler(data, observed_indicator) + return data, loc, scale + + +@add_start_docstrings( + "The bare PatchTST Model outputting raw hidden-states without any specific head.", + PATCHTST_START_DOCSTRING, +) +class PatchTSTModel(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + self.scaler = PatchTSTScaler(config) + self.patchifier = PatchTSTPatchify(config) + self.do_mask_input = config.do_mask_input + # get num_patches information from PatchTSTPatchify + num_patches = self.patchifier.num_patches + + if self.do_mask_input: + self.masking = PatchTSTMasking(config) + else: + self.masking = nn.Identity() + self.encoder = PatchTSTEncoder(config, num_patches=num_patches) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSTModelOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*): + Future target values associated with the `past_values` + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): + Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False) + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import PatchTSTModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain") + + >>> # during training, one provides both past and future values + >>> outputs = model( + ... past_values=batch["past_values"], + ... future_values=batch["future_values"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + + 
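# forward flow: scale the raw series, cut it into patches, optionally mask the patches (pretraining only), then encode + 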
return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if past_observed_mask is None: + past_observed_mask = torch.ones_like(past_values) + + # scaled_past_values: [bs x sequence_length x num_input_channels] + scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask) + + # patched_values: [bs x num_input_channels x num_patches x patch_length] for pretrain + patched_values = self.patchifier(scaled_past_values) + if self.do_mask_input: + masked_values, mask = self.masking(patched_values) + else: + masked_values, mask = self.masking(patched_values), None + + encoder_output = self.encoder( + patch_input=masked_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + + if not return_dict: + outputs = (encoder_output.last_hidden_state, encoder_output.hidden_states, encoder_output.attentions) + outputs = outputs + (mask, loc, scale, patched_values) + return tuple(v for v in outputs if v is not None) + + return PatchTSTModelOutput( + last_hidden_state=encoder_output.last_hidden_state, + hidden_states=encoder_output.hidden_states, + attentions=encoder_output.attentions, + mask=mask, + loc=loc, + scale=scale, + patch_input=patched_values, + ) + + +class PatchTSTMaskPretrainHead(nn.Module): + """ + Pretraining head for mask modelling + """ + + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + self.linear = nn.Linear(config.d_model, config.patch_length) + self.use_cls_token = config.use_cls_token + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, num_channels, num_patches, patch_length)`: the reconstructed patches + (the cls token, if any, is removed) + + """ + embedding = self.linear(self.dropout(embedding)) # [bs x num_channels x num_patches x patch_length] + if self.use_cls_token: + embedding = embedding[:, :, 1:, :] # remove the first cls token + return embedding + + +@add_start_docstrings( + "The PatchTST for pretraining model.", + PATCHTST_START_DOCSTRING, +) +class PatchTSTForPretraining(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + config.do_mask_input = True + self.model = PatchTSTModel(config=config) + self.head = PatchTSTMaskPretrainHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSTForPretrainingOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean 
mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import PatchTSTConfig, PatchTSTForPretraining + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> # Config for random mask pretraining + >>> config = PatchTSTConfig( + ... num_input_channels=7, + ... context_length=512, + ... patch_length=12, + ... stride=12, + ... mask_type='random', + ... random_mask_ratio=0.4, + ... use_cls_token=True, + ... ) + >>> # Config for forecast mask pretraining + >>> config = PatchTSTConfig( + ... num_input_channels=7, + ... context_length=512, + ... patch_length=12, + ... stride=12, + ... mask_type='forecast', + ... num_forecast_mask_patches=5, + ... use_cls_token=True, + ... ) + >>> model = PatchTSTForPretraining(config) + + >>> # during training, one only provides past values; masking and the + >>> # reconstruction loss are computed internally + >>> outputs = model(past_values=batch["past_values"]) + + >>> loss = outputs.loss + >>> loss.backward() + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # model_output.last_hidden_state: [bs x num_channels x num_patches x d_model] or + # [bs x num_channels x (num_patches+1) x d_model] if use cls_token + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + + # x_hat: [bs x num_channels x num_patches x patch_length] + # (the head removes the cls token if use_cls_token) + x_hat = self.head(model_output.last_hidden_state) + + # calculate masked_loss + loss = nn.MSELoss(reduction="none") + loss_val = loss(x_hat, model_output.patch_input) + masked_loss = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10) + + encoder_states = model_output.hidden_states + if not return_dict: + outputs = (x_hat,) + model_output[1:-4] + outputs = (masked_loss,) + outputs if masked_loss is not None else outputs + return outputs + return PatchTSTForPretrainingOutput( + loss=masked_loss, prediction_output=x_hat, hidden_states=encoder_states, attentions=model_output.attentions + ) + + 
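+# The classification head below reduces [bs x num_channels x num_patches x d_model] to [bs x num_targets]: +# it takes the cls token (or mean/max pools over patches), flattens the channel and d_model dimensions, +# and applies dropout followed by a single linear layer. + 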
+class PatchTSTClassificationHead(nn.Module): + def __init__(self, config: PatchTSTConfig): + super().__init__() + self.use_cls_token = config.use_cls_token + self.pooling_type = config.pooling_type + self.flatten = nn.Flatten(start_dim=1) + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets) + + def forward(self, embedding: torch.Tensor): + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, num_targets)` + + """ + if self.use_cls_token: + # use the first output token, pooled_embedding: bs x num_channels x d_model + pooled_embedding = embedding[:, :, 0, :] + elif self.pooling_type == "mean": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.mean(dim=2) + elif self.pooling_type == "max": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.max(dim=2).values + else: + raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet") + # pooled_embedding: bs x num_channels * d_model + pooled_embedding = self.flatten(pooled_embedding) + # output: bs x n_classes + output = self.linear(self.dropout(pooled_embedding)) + return output + + +@add_start_docstrings( + "The PatchTST for classification model.", + PATCHTST_START_DOCSTRING, +) +class PatchTSTForClassification(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + # Turn off masking + if config.do_mask_input: + logger.warning("Setting `do_mask_input` parameter to False.") + config.do_mask_input = False + + self.model = PatchTSTModel(config) + self.head = PatchTSTClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + target_values: Optional[torch.Tensor] = None, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForClassificationOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + target_values (`torch.Tensor`, *optional*): + Labels associated with the `past_values` + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): + Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForClassificationOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + Examples: + + ```python + >>> import torch + >>> from transformers import PatchTSTConfig, PatchTSTForClassification + + >>> # classification task with two input channels and 3 classes + >>> config = PatchTSTConfig( + ... num_input_channels=2, + ... num_targets=3, + ... context_length=512, + ... patch_length=12, + ... stride=12, + ... use_cls_token=True, + ... 
) + >>> model = PatchTSTForClassification(config=config) + + >>> # during inference, one only provides past values + >>> past_values = torch.randn(20, 512, 2) + >>> outputs = model(past_values=past_values) + >>> logits = outputs.prediction_logits + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + if target_values is not None: + loss = nn.CrossEntropyLoss() + loss_val = loss(y_hat, target_values) + + if not return_dict: + outputs = (y_hat,) + model_output[1:-3] + outputs = (loss_val,) + outputs if loss_val is not None else outputs + return outputs + return PatchTSTForClassificationOutput( + loss=loss_val, + prediction_logits=y_hat, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + ) + + +class PatchTSTPredictionHead(nn.Module): + def __init__(self, config: PatchTSTConfig, num_patches, distribution_output=None): + super().__init__() + + self.share_projection = config.share_projection + self.num_input_channels = config.num_input_channels + self.use_cls_token = config.use_cls_token + self.pooling_type = config.pooling_type + if self.pooling_type or self.use_cls_token: + head_dim = config.d_model + else: + head_dim = config.d_model * num_patches + + if not self.share_projection: + # if each channel has its own head + self.projections = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.num_input_channels): + self.flattens.append(nn.Flatten(start_dim=2)) + if distribution_output is None: + # use linear head + self.projections.append(nn.Linear(head_dim, config.prediction_length)) + else: + # use distribution head + self.projections.append(distribution_output.get_parameter_projection(head_dim)) + self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()) + else: + # all the channels share the same head + self.flatten = nn.Flatten(start_dim=2) + if distribution_output is None: + # use linear head + self.projection = nn.Linear(head_dim, config.prediction_length) + else: + # use distribution head + self.projection = distribution_output.get_parameter_projection(head_dim) + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + + def forward(self, embedding: torch.Tensor): + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, forecast_len, num_channels)` + + """ + if self.use_cls_token: + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding[:, :, 0, :] + else: + if self.pooling_type == "mean": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.mean(dim=2) + elif self.pooling_type == "max": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.max(dim=2).values + else: + # pooled_embedding: [bs x num_channels x num_patches x d_model] + pooled_embedding = embedding + + 
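# two projection modes: with share_projection=False each channel has its own flatten/dropout/projection stack; + # otherwise one shared stack is applied to every channel + 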
if not self.share_projection: + output = [] + for i in range(self.num_input_channels): + # channel_embedding: [bs x (d_model * num_patches)] or [bs x d_model] + channel_embedding = self.flattens[i](pooled_embedding[:, i, :]) + channel_embedding = self.dropouts[i](channel_embedding) + # channel_embedding: [bs x forecast_len] + # or tuple ([bs x forecast_len], [bs x forecast_len]) if using distribution head + channel_embedding = self.projections[i](channel_embedding) + output.append(channel_embedding) + # output: [bs x num_channels x forecast_len] + output = torch.stack(output, dim=1) + else: + # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model] + pooled_embedding = self.flatten(pooled_embedding) + pooled_embedding = self.dropout(pooled_embedding) + # output: [bs x num_channels x forecast_len] or + # tuple ([bs x num_channels x forecast_len], [bs x num_channels x forecast_len]) if using distribution head + output = self.projection(pooled_embedding) + + if isinstance(output, tuple): + # output: ([bs x forecast_len x num_channels], [bs x forecast_len x num_channels]) + output = tuple(z.transpose(2, 1) for z in output) + else: + output = output.transpose(2, 1) # [bs x forecast_len x num_channels] + return output + + +@add_start_docstrings( + "The PatchTST for prediction model.", + PATCHTST_START_DOCSTRING, +) +class PatchTSTForPrediction(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + # Turn off masking + if config.do_mask_input: + logger.warning("Setting `do_mask_input` parameter to False.") + config.do_mask_input = False + + self.model = PatchTSTModel(config) + + if config.loss == "mse": + self.distribution_output = None + else: + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.prediction_length) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.prediction_length) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.head = PatchTSTPredictionHead( + config, self.model.patchifier.num_patches, distribution_output=self.distribution_output + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PatchTSTForPredictionOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). 
+ future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*): + Future target values associated with the `past_values` + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): + Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import PatchTSTConfig, PatchTSTForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> # Prediction task with 7 input channels and prediction length is 96 + >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast") + + >>> # during training, one provides both past and future values + >>> outputs = model( + ... past_values=batch["past_values"], + ... future_values=batch["future_values"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values, the model outputs future values + >>> outputs = model(past_values=batch["past_values"]) + >>> prediction_outputs = outputs.prediction_outputs + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # get model output + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + # get output head + y_hat = self.head(model_output.last_hidden_state) + + loss_val = None + + if self.distribution_output: + y_hat_out = y_hat + else: + y_hat_out = y_hat * model_output.scale + model_output.loc + + if future_values is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution( + y_hat, loc=model_output.loc, scale=model_output.scale + ) + loss_val = nll(distribution, future_values) + # take average of the loss + loss_val = weighted_average(loss_val) + else: + loss = nn.MSELoss(reduction="mean") + loss_val = loss(y_hat_out, future_values) + + loc = model_output.loc + scale = model_output.scale + + if not return_dict: + outputs = (y_hat_out,) + model_output[1:-1] + outputs = (loss_val,) + outputs if loss_val is not None else outputs + return outputs + return PatchTSTForPredictionOutput( + loss=loss_val, + prediction_outputs=y_hat_out, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + loc=loc, + scale=scale, + ) + + def generate( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSTOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. 
Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Return: + [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)` + for multivariate predictions. + """ + # get number of samples + num_parallel_samples = self.config.num_parallel_samples + + # get model output + outputs = self( + past_values=past_values, + future_values=None, + past_observed_mask=past_observed_mask, + output_hidden_states=False, + ) + if self.distribution_output: + # get distribution + distribution = self.distribution_output.distribution( + outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale + ) + # get samples: list of [bs x forecast_len x num_channels] + samples = [distribution.sample() for _ in range(num_parallel_samples)] + # samples: [bs x num_samples x forecast_len x num_channels] + samples = torch.stack(samples, dim=1) + else: + samples = outputs.prediction_outputs.unsqueeze(1) + + return SamplePatchTSTOutput(sequences=samples) + + +class PatchTSTRegressionHead(nn.Module): + """ + Regression head + """ + + def __init__(self, config: PatchTSTConfig, distribution_output=None): + super().__init__() + self.y_range = config.output_range + self.use_cls_token = config.use_cls_token + self.pooling_type = config.pooling_type + self.distribution_output = distribution_output + + head_dim = config.num_input_channels * config.d_model + + self.flatten = nn.Flatten(start_dim=1) + self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity() + + if distribution_output is None: + self.projection = nn.Linear(head_dim, config.num_targets) + else: + self.projection = distribution_output.get_parameter_projection(head_dim) + + def forward(self, embedding: torch.Tensor): + """ + Parameters: + embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or + `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*): + Embedding from the model + Returns: + `torch.Tensor` of shape `(bs, output_dim)` + + """ + if self.use_cls_token: + # use the first output token, pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding[:, :, 0, :] + elif self.pooling_type == "mean": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.mean(dim=2) + elif self.pooling_type == "max": + # pooled_embedding: [bs x num_channels x d_model] + pooled_embedding = embedding.max(dim=2).values + else: + raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet") + # flatten the input + # pooled_embedding: bs x (num_channels * d_model) + pooled_embedding = self.dropout(self.flatten(pooled_embedding)) + # projection + # output: bs x output_dim or a tuple of this shape for distribution head + output = self.projection(pooled_embedding) + # apply sigmoid to bound the output if required + if (self.distribution_output is None) and (self.y_range is not None): # linear head + output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0] + return output + + +@add_start_docstrings( + "The PatchTST for regression model.", + PATCHTST_START_DOCSTRING, +) +class PatchTSTForRegression(PatchTSTPreTrainedModel): + def __init__(self, config: PatchTSTConfig): + super().__init__(config) + + # Turn off masking + if config.do_mask_input: + 
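# input masking is only used by the pretraining objective, so it is force-disabled here + 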
logger.warning("Setting `do_mask_input` parameter to False.") + config.do_mask_input = False + + self.model = PatchTSTModel(config) + if config.loss == "mse": + self.distribution_output = None + else: + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.num_targets) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.num_targets) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.num_targets) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.head = PatchTSTRegressionHead(config, self.distribution_output) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + past_values: torch.Tensor, + target_values: Optional[torch.Tensor] = None, + past_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, PatchTSTForRegressionOutput]: + r""" + Parameters: + past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*): + Input sequence to the model + target_values (`torch.Tensor` of shape `(bs, num_targets)`): + Target values associated with the `past_values` + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers + output_attentions (`bool`, *optional*): + Whether or not to return the output attention of all layers + return_dict (`bool`, *optional*): + Whether or not to return a `ModelOutput` instead of a plain tuple. + + Returns: + `PatchTSTForRegressionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or + `config.return_dict`=False) + + Examples: + + ```python + >>> import torch + >>> from transformers import PatchTSTConfig, PatchTSTForRegression + + >>> # Regression task with 6 input channels, regressing 2 targets + >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression") + + >>> # during inference, one only provides past values; the model outputs the regression targets + >>> past_values = torch.randn(20, 512, 6) + >>> outputs = model(past_values=past_values) + >>> regression_outputs = outputs.regression_outputs + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_output = self.model( + past_values=past_values, + past_observed_mask=past_observed_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + ) + # get output head. 
y_hat is of shape [bs x num_targets] or tuple of this shape + y_hat = self.head(model_output.last_hidden_state) + + loss = None + if target_values is not None: + if self.distribution_output: + distribution = self.distribution_output.distribution(y_hat) + # y_hat should be a 2-tuple, each with dimension [bs, num_targets] + y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat]) + loss = nll(distribution, target_values) + # take average of the loss + loss = weighted_average(loss) + else: + loss = nn.MSELoss(reduction="mean") + loss = loss(y_hat, target_values) + + if not return_dict: + # hidden_states, attentions, mask + outputs = (y_hat,) + model_output[1:-3] + outputs = (loss,) + outputs if loss is not None else outputs + return outputs + return PatchTSTForRegressionOutput( + loss=loss, + regression_outputs=y_hat, + hidden_states=model_output.hidden_states, + attentions=model_output.attentions, + ) + + def generate( + self, + past_values: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + ) -> SamplePatchTSTOutput: + """ + Generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Past values of the time series that serves as context in order to predict the future. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + Return: + [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, num_targets)`. + """ + # get number of samples + num_parallel_samples = self.config.num_parallel_samples + + # get model output + outputs = self( + past_values=past_values, + target_values=None, + past_observed_mask=past_observed_mask, + output_hidden_states=False, + ) + + # get distribution + distribution = self.distribution_output.distribution(outputs.regression_outputs) + # get samples: list of [bs x num_targets] + samples = [distribution.sample() for _ in range(num_parallel_samples)] + # samples: [bs x num_samples x num_targets] + samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets) + return SamplePatchTSTOutput(sequences=samples) + + +__all__ = [ + "PatchTSTModel", + "PatchTSTPreTrainedModel", + "PatchTSTForPrediction", + "PatchTSTForPretraining", + "PatchTSTForRegression", + "PatchTSTForClassification", +] diff --git a/docs/transformers/src/transformers/models/pegasus/__init__.py b/docs/transformers/src/transformers/models/pegasus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4903c400f9827239f7920b7e2c9bfddfda48eecd --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/docs/transformers/src/transformers/models/pegasus/__init__.py b/docs/transformers/src/transformers/models/pegasus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4903c400f9827239f7920b7e2c9bfddfda48eecd
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pegasus/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_pegasus import *
+    from .modeling_flax_pegasus import *
+    from .modeling_pegasus import *
+    from .modeling_tf_pegasus import *
+    from .tokenization_pegasus import *
+    from .tokenization_pegasus_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/pegasus/configuration_pegasus.py b/docs/transformers/src/transformers/models/pegasus/configuration_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bff5d7719f45061328757c916dbfcfaaa0c365
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PEGASUS model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate a
+    PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the PEGASUS
+    [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the PEGASUS model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`PegasusModel`] or [`TFPegasusModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 1):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
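+        Note that the generic attributes `num_attention_heads` and `hidden_size` are aliases for
+        `encoder_attention_heads` and `d_model` (see `attribute_map` and the properties defined below); a minimal
+        sketch:
+
+        ```python
+        >>> from transformers import PegasusConfig
+
+        >>> config = PegasusConfig(d_model=512, encoder_attention_heads=8)
+        >>> config.hidden_size
+        512
+        >>> config.num_attention_heads
+        8
+        ```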
+ + Example: + + ```python + >>> from transformers import PegasusConfig, PegasusModel + + >>> # Initializing a PEGASUS google/pegasus-large style configuration + >>> configuration = PegasusConfig() + + >>> # Initializing a model (with random weights) from the google/pegasus-large style configuration + >>> model = PegasusModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pegasus" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=50265, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=0, + scale_embedding=False, + pad_token_id=0, + eos_token_id=1, + forced_eos_token_id=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + +__all__ = ["PegasusConfig"] diff --git a/docs/transformers/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py b/docs/transformers/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..cf183b590c1b853099abae10ded4aa6a120fe107 --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2020 Google and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
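+# A hedged sketch of the substring-rewrite idea the converter below applies to TF
+# variable names (the patterns here are illustrative, not the script's PATTERNS):
+#
+#     def rename_key(key: str) -> str:
+#         for old, new in [["attention", "attn"], ["/", "."], ["kernel", "weight"]]:
+#             key = key.replace(old, new)  # order matters: later rules see earlier output
+#         return key
+#
+#     rename_key("encoder/layer_0/attention/kernel")  # -> "encoder.layer_0.attn.weight"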
+ +import argparse +import os +from pathlib import Path +from typing import Dict + +import tensorflow as tf +import torch +from tqdm import tqdm + +from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer +from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params + + +PATTERNS = [ + # replace left string with right string to get the relevant state_dict key (identical state dict to bart) + ["memory_attention", "encoder_attn"], + ["attention", "attn"], + ["/", "."], + [".LayerNorm.gamma", "_layer_norm.weight"], + [".LayerNorm.beta", "_layer_norm.bias"], + ["r.layer_", "r.layers."], + ["output_proj", "out_proj"], + ["ffn.dense_1.", "fc2."], + ["ffn.dense.", "fc1."], + ["ffn_layer_norm", "final_layer_norm"], + ["kernel", "weight"], + ["encoder_layer_norm.", "encoder.layer_norm."], + ["decoder_layer_norm.", "decoder.layer_norm."], + ["embeddings.weights", "shared.weight"], +] + + +def rename_state_dict_key(k): + for pegasus_name, hf_name in PATTERNS: + k = k.replace(pegasus_name, hf_name) + return k + + +# See appendix C of paper for all hyperparams + + +def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration: + cfg_kwargs = DEFAULTS.copy() + cfg_kwargs.update(cfg_updates) + cfg = PegasusConfig(**cfg_kwargs) + torch_model = PegasusForConditionalGeneration(cfg) + sd = torch_model.model.state_dict() + mapping = {} + for k, v in tf_weights.items(): + new_k = rename_state_dict_key(k) + if new_k not in sd: + raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") + + if "dense" in k or "proj" in new_k: + v = v.T + mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype) + assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}" + # make sure embedding.padding_idx is respected + mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1]) + mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"] + mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"] + empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping} + mapping.update(**empty_biases) + missing, extra = torch_model.model.load_state_dict(mapping, strict=False) + unexpected_missing = [ + k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"] + ] + assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" + assert extra == [], f"no matches found for the following tf keys {extra}" + return torch_model + + +def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict: + init_vars = tf.train.list_variables(path) + tf_weights = {} + ignore_name = ["Adafactor", "global_step"] + for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): + skip_key = any(pat in name for pat in ignore_name) + if skip_key: + continue + array = tf.train.load_variable(path, name) + tf_weights[name] = array + return tf_weights + + +def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str): + # save tokenizer first + dataset = Path(ckpt_path).parent.name + desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"] + tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length) + assert tok.model_max_length == desired_max_model_length + tok.save_pretrained(save_dir) + + # convert model + 
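+    # Remaining steps: load raw TF variables, build the config from the
+    # dataset-specific hyperparameters, rename/transpose weights into the PyTorch
+    # state dict, then save with the (deterministic, sinusoidal) position
+    # embeddings stripped so they are rebuilt at load time.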
tf_weights = get_tf_weights_as_numpy(ckpt_path)
+    cfg_updates = task_specific_params[f"summarization_{dataset}"]
+    if dataset == "large":
+        cfg_updates["task_specific_params"] = task_specific_params
+    torch_model = convert_pegasus(tf_weights, cfg_updates)
+    torch_model.save_pretrained(save_dir)
+    sd = torch_model.state_dict()
+    sd.pop("model.decoder.embed_positions.weight")
+    sd.pop("model.encoder.embed_positions.weight")
+    torch.save(sd, Path(save_dir) / "pytorch_model.bin")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
+    parser.add_argument("save_dir", nargs="?", default=None, type=str, help="Path to the output PyTorch model.")
+    args = parser.parse_args()
+    if args.save_dir is None:
+        dataset = Path(args.tf_ckpt_path).parent.name
+        args.save_dir = os.path.join("pegasus", dataset)
+    convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
diff --git a/docs/transformers/src/transformers/models/pegasus/modeling_flax_pegasus.py b/docs/transformers/src/transformers/models/pegasus/modeling_flax_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd450698937ccb8e52c1d02ca2cb3be750a7823f
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -0,0 +1,1532 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Flax PEGASUS model."""
+
+import math
+import random
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutput,
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    FlaxCausalLMOutputWithCrossAttentions,
+    FlaxSeq2SeqLMOutput,
+    FlaxSeq2SeqModelOutput,
+)
+from ...modeling_flax_utils import (
+    ACT2FN,
+    FlaxPreTrainedModel,
+    add_start_docstrings_to_model_forward,
+    append_call_sample_docstring,
+    append_replace_return_docstrings,
+    overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, logging, replace_return_docstrings
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+PEGASUS_START_DOCSTRING = r"""
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+ + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`PegasusConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +PEGASUS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +PEGASUS_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +PEGASUS_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. 
+ + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions +def create_sinusoidal_positions(n_pos, dim): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + sentinel = dim // 2 + dim % 2 + out = np.zeros_like(position_enc) + out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) + out[:, sentinel:] = np.cos(position_enc[:, 1::2]) + + return jnp.array(out) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus +class FlaxPegasusAttention(nn.Module): + config: PegasusConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
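+            # For example, with max_length=8, cur_index=3 and one new vector, positions
+            # 0..3 remain attendable while the still-empty slots 4..7 are masked out.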
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
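+        # (0.0 where attention is allowed, the dtype's most negative finite value where
+        # it is masked, so masked positions contribute nothing after the softmax)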
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus +class FlaxPegasusEncoderLayer(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus +class FlaxPegasusEncoderLayerCollection(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + 
output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus +class FlaxPegasusDecoderLayer(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.encoder_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, 
+ attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus +class FlaxPegasusDecoderLayerCollection(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxPegasusEncoder(nn.Module): + config: PegasusConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 + + self.embed_positions = 
create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) + self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + embed_pos = jnp.take(self.embed_positions, position_ids, axis=0) + # explicitly cast the positions here, since self.embed_positions are not registered as parameters + embed_pos = embed_pos.astype(inputs_embeds.dtype) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + last_hidden_state = self.layer_norm(last_hidden_state) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_state,) + + if not return_dict: + outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxPegasusDecoder(nn.Module): + config: PegasusConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) + + self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = jnp.take(self.embed_positions, position_ids, axis=0) + # explicitly cast the positions here, since self.embed_positions are not registered as parameters + positions = positions.astype(inputs_embeds.dtype) + + hidden_states = inputs_embeds + positions + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + last_hidden_state = self.layer_norm(last_hidden_state) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_state,) + + if not return_dict: + outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus +class FlaxPegasusModule(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): + config_class = PegasusConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: PegasusConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 
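+        # Runs the module once on dummy inputs to materialize parameter shapes; if a
+        # partial `params` tree was provided, only the keys recorded as missing are
+        # filled from the random initialization.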
+ # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
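+
+        Example (a sketch; assumes `model` is a `FlaxPegasusForConditionalGeneration` and `encoder_outputs`
+        came from `model.encode`):
+
+        ```python
+        >>> past_key_values = model.init_cache(batch_size=1, max_length=64, encoder_outputs=encoder_outputs)
+        ```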
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.", + PEGASUS_START_DOCSTRING, +) +class FlaxPegasusModel(FlaxPegasusPreTrainedModel): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxPegasusModule + + +append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus +class FlaxPegasusForConditionalGenerationModule(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.model.shared.num_embeddings, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + 
attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["shared"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING +) +class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): + module_class = FlaxPegasusForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + deterministic: bool = True, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large") + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias.astype(self.dtype) + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add 
updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
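Editor's note: the two generation hooks above are what let `generate` compile a single decoding step. `prepare_inputs_for_generation` allocates one attention mask at the final `max_length` up front (safe because the causal mask already hides not-yet-written positions), and `update_inputs_for_generation` then only advances the position ids between steps. Below is a minimal sketch of that mask and position bookkeeping in plain `jax.numpy`; it is an editorial illustration, not part of the source file, and every variable name is local to the sketch:

```python
import jax.numpy as jnp
from jax import lax

batch_size, max_length = 2, 8

# One static (batch, max_length) mask: every decoding step then has the same
# shape, so XLA compiles the step function once.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")

# A 3-token prompt whose second row ends in padding.
decoder_attention_mask = jnp.array([[1, 1, 1], [1, 1, 0]], dtype="i4")

# Positions count only non-pad tokens, as in `prepare_inputs_for_generation`.
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1  # [[0 1 2], [0 1 1]]

# Copy the real prompt mask into the static buffer; the remaining columns stay
# 1 but are never attended to, thanks to the causal mask.
extended_attention_mask = lax.dynamic_update_slice(
    extended_attention_mask, decoder_attention_mask, (0, 0)
)

# Between steps, `update_inputs_for_generation` just increments the last position.
next_position_ids = position_ids[:, -1:] + 1  # [[3], [2]]
```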
+
+
+FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """
+    Returns:
+
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+    >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')
+
+    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs['input_ids']).sequences
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> import jax
+    >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+    >>> TXT = "My friends are <mask_2> but they eat too many carbs."
+
+    >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+    >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
+    >>> logits = model(input_ids).logits
+
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
+    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+    >>> values, predictions = jax.lax.top_k(probs, k=5)
+
+    >>> tokenizer.decode(predictions).split()
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+__all__ = ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/pegasus/modeling_pegasus.py b/docs/transformers/src/transformers/models/pegasus/modeling_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c5a0acdcd8b8394459b244e845be7b53f63485
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,1633 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PEGASUS model."""
+
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
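+
+    For example (an editorial illustration, not in the original docstring): with `pad_token_id=0`
+    and `decoder_start_token_id=2`, labels `[[5, -100, -100]]` become `[[2, 5, 0]]`: the start
+    token is prepended, everything shifts right by one, and the `-100` ignore-index that survives
+    the shift is replaced by the pad id.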
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus +class PegasusSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + + def _init_weight(self): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = self.weight.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + self.weight = nn.Parameter(out, requires_grad=False) + + @torch.no_grad() + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus +class PegasusAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PegasusConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
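+        # Editor's note: at this point `attn_output` has shape (bsz, tgt_len, num_heads, head_dim)
+        # after the `view` and `transpose` above; the `reshape` below merges the two head axes back
+        # into embed_dim = num_heads * head_dim before the final output projection.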
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS +class PegasusEncoderLayer(nn.Module): + def __init__(self, config: PegasusConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
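+
+        (Editor's note) The layer returns a tuple: `hidden_states` first, with the attention
+        weights appended as a second element only when `output_attentions=True`.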
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusDecoderLayer(nn.Module):
+    def __init__(self, config: PegasusConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(decoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class PegasusPreTrainedModel(PreTrainedModel): + config_class = PegasusConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, PegasusSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PEGASUS_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`PegasusConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration
+
+    >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+    >>> ARTICLE_TO_SUMMARIZE = (
+    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+    ... )
+    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"])
+    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "California's largest electricity provider has turned off power to hundreds of thousands of customers."
+    ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder.
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class PegasusEncoder(PegasusPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PegasusEncoderLayer`]. + + Args: + config: PegasusConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + self.padding_idx, + ) + self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + self.config.max_position_embeddings, + self.config.d_model, + self.padding_idx, + ) + self.embed_positions._init_weight() + self.embed_positions.to(self.device) + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.embed_positions + + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class PegasusDecoder(PegasusPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`] + + Args: + config: PegasusConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + self.padding_idx, + ) + self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
+ """ + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + self.embed_positions = PegasusSinusoidalPositionalEmbedding( + self.config.max_position_embeddings, + self.config.d_model, + self.padding_idx, + ) + self.embed_positions._init_weight() + self.embed_positions.to(self.device) + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.embed_positions + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input_shape, past_key_values_length) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
+    PEGASUS_START_DOCSTRING,
+)
+class PegasusModel(PegasusPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: PegasusConfig):
+        super().__init__(config)
+
+        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+        self.encoder = PegasusEncoder(config, self.shared)
+        self.decoder = PegasusDecoder(config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def
get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + self.config.max_position_embeddings = new_num_position_embeddings + self.encoder.resize_position_embeddings(new_num_position_embeddings) + self.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> Tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings()) + + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, PegasusModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + >>> model = PegasusModel.from_pretrained("google/pegasus-large") + + >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt") + >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt") + >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 4, 1024] + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + 
attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The PEGASUS Model with a language modeling head. 
Can be used for summarization.", PEGASUS_START_DOCSTRING +) +class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PegasusConfig): + super().__init__(config) + self.model = PegasusModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
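+
+                Note that Pegasus uses sinusoidal (non-learned) position embeddings, so the second case applies
+                here: the resized matrix is recomputed deterministically from the position-encoding algorithm.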
+ """ + self.config.max_position_embeddings = new_num_position_embeddings + self.model.encoder.resize_position_embeddings(new_num_position_embeddings) + self.model.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> Tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings()) + + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
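+
+            If `labels` are provided while `decoder_input_ids` and `decoder_inputs_embeds` are both `None`, the
+            decoder inputs are created automatically by shifting the labels one position to the right (see
+            `shift_tokens_right` in the `forward` body below).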
+ + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus +class PegasusDecoderWrapper(PegasusPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
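+    Keeping the decoder under a `decoder` attribute of `model` preserves the `model.decoder.*` parameter names
+    that such checkpoints use, so the weights load without any remapping.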
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PegasusDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = PegasusDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.model.decoder.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + self.config.max_position_embeddings = new_num_position_embeddings + self.model.decoder.resize_position_embeddings(new_num_position_embeddings) + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, PegasusForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large") + >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/pegasus/modeling_tf_pegasus.py b/docs/transformers/src/transformers/models/pegasus/modeling_tf_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..a51835dcfa447f889913b3de7506d38af690b741 --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -0,0 +1,1574 @@ +# coding=utf-8 +# Copyright 2021, Google Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Pegasus model."""
+
+from __future__ import annotations
+
+import random
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFSeq2SeqLMOutput,
+    TFSeq2SeqModelOutput,
+)
+
+# Public API
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+    start_tokens = tf.fill(
+        (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
+    )
+    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = tf.where(
+        shifted_input_ids == -100,
+        tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
+        shifted_input_ids,
+    )
+
+    # Verify that `labels` has only positive values and -100
+    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+    # Make sure the assertion op is called by wrapping the result in an identity no-op
+    with tf.control_dependencies([assert_gte0]):
+        shifted_input_ids = tf.identity(shifted_input_ids)
+
+    return shifted_input_ids
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+    """
+    Make causal mask used for uni-directional (decoder) self-attention.
+    """
+    bsz = input_ids_shape[0]
+    tgt_len = input_ids_shape[1]
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+    mask_cond = tf.range(shape_list(mask)[-1])
+
+    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+    if past_key_values_length > 0:
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
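+    Ones (positions to attend to) become an additive bias of 0.0 and zeros (masked positions) become
+    `LARGE_NEGATIVE`, so the result can be added directly to raw attention logits.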
+ """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus +class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, **kwargs): + super().__init__(**kwargs) + + if embedding_dim % 2 != 0: + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") + + self.embedding_dim = embedding_dim + self.num_positions = num_positions + + def build(self, input_shape: tf.TensorShape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + + weight = self._init_weight(self.num_positions, self.embedding_dim) + + self.weight = self.add_weight( + name="embeddings", + shape=[self.num_positions, self.embedding_dim], + ) + weight = tf.cast(weight, dtype=self.weight.dtype) + + self.weight.assign(weight) + + super().build(input_shape) + + @staticmethod + def _init_weight(n_pos: int, dim: int): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + table = np.zeros_like(position_enc) + # index 0 is all zero + table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) + table[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) + # convert to tensor + table = tf.convert_to_tensor(table) + tf.stop_gradient(table) + return table + + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None + ): + """Input is expected to be of size [bsz x seqlen].""" + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return tf.gather(self.weight, position_ids) + + +# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus +class TFPegasusAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: Tuple[Tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus +class TFPegasusEncoderLayer(keras.layers.Layer): + def __init__(self, config: PegasusConfig, **kwargs): + super().__init__(**kwargs) 
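+        # Pre-LayerNorm block: in `call`, `self_attn_layer_norm` and `final_layer_norm` are applied
+        # *before* their respective sub-layers, with a residual connection added around each sub-layer.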
+ self.embed_dim = config.d_model + self.self_attn = TFPegasusAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + layer_head_mask: tf.Tensor, + training: Optional[bool] = False, + ): + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(encoder_attention_heads,)* + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask + ) + + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return hidden_states, self_attn_weights + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.encoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus +class TFPegasusDecoderLayer(keras.layers.Layer): + def __init__(self, config: PegasusConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFPegasusAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = 
keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFPegasusAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + cross_attn_layer_head_mask: tf.Tensor | None = None, + past_key_value: Tuple[tf.Tensor] | None = None, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`tf.Tensor`): + cross attention input to the layer of shape *(batch, seq_len, embed_dim)* + encoder_attention_mask (`tf.Tensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(decoder_attention_heads,)* + cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. 
+ *(decoder_attention_heads,)* + past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return ( + hidden_states, + self_attn_weights, + cross_attn_weights, + present_key_value, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "encoder_attn", None) is not None: + with tf.name_scope(self.encoder_attn.name): + self.encoder_attn.build(None) + if getattr(self, "encoder_attn_layer_norm", None) is not None: + with tf.name_scope(self.encoder_attn_layer_norm.name): + self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.decoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFPegasusPreTrainedModel(TFPreTrainedModel): + config_class = PegasusConfig + base_model_prefix = "model" + + +PEGASUS_START_DOCSTRING = r""" + This model 
inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+    Args:
+        config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, TFPegasusForConditionalGeneration
+
+    >>> model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+    >>> ARTICLE_TO_SUMMARIZE = (
+    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+    ... )
+    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"])
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            A causal mask that also ignores pad tokens will be made by default. It is not recommended to set this
+            for most use cases.
+        decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tf.FloatTensor`, *optional*):
+            Sequence of hidden states at the output of the last layer of the encoder, of shape `(batch_size,
+            sequence_length, hidden_size)`. Used in the cross-attention of the decoder.
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFPegasusEncoder(keras.layers.Layer):
+    config_class = PegasusConfig
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`TFPegasusEncoderLayer`].
+
+    Args:
+        config: PegasusConfig
+    """
+
+    def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.dropout = keras.layers.Dropout(config.dropout)
+        self.layerdrop = config.encoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+
+        self.embed_tokens = embed_tokens
+        self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+            name="embed_positions",
+        )
+        self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+    def get_embed_tokens(self):
+        return self.embed_tokens
+
+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: tf.Tensor | None = None,
+        inputs_embeds: tf.Tensor | None = None,
+        attention_mask: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ):
+        """
+        Args:
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+                in the config will be used instead.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+                will be used instead.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+                in eager mode, in graph mode the value will always be set to True.
+            training (`bool`, *optional*, defaults to `False`):
+                Whether or not to use the model in training mode (some modules like dropout modules have different
+                behaviors between training and evaluation).
+        """
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        embed_pos = self.embed_positions(input_shape)
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.dropout(hidden_states, training=training)
+
+        # check attention mask and invert
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _expand_mask(attention_mask)
+        else:
+            attention_mask = None
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(head_mask)[0],
+                len(self.layers),
+                message=(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {shape_list(head_mask)[0]}."
+ ), + ) + + # encoder layers + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): # skip the layer + continue + + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + ) + + if output_attentions: + all_attentions += (attn,) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFPegasusDecoder(keras.layers.Layer): + config_class = PegasusConfig + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`] + + Args: + config: PegasusConfig + embed_tokens: output embedding + """ + + def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.embed_tokens = embed_tokens + self.layerdrop = config.decoder_layerdrop + self.embed_positions = TFPegasusSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + self.dropout = keras.layers.Dropout(config.dropout) + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + cross_attn_head_mask: tf.Tensor | None = None, + past_key_values: Tuple[Tuple[tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ): + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+                range `[0, config.max_position_embeddings - 1]`.
+            encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding token indices of encoder `input_ids`. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`, with each entry holding the 2
+                key/value tensors of the self-attention and the 2 key/value tensors of the cross-attention, each of
+                shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+                Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+                decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead
+                of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated
+                vectors than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail. This argument can be used only in eager mode, in graph mode the
+                value in the config will be used instead.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+                will be used instead.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
This argument can be used + in eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + # embed positions + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + hidden_states = inputs_embeds + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] + ) + + if attention_mask is not None: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) + + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) + + hidden_states = self.dropout(hidden_states + positions, training=training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + present_key_values = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." 
+ ), + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + + if training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if encoder_hidden_states is not None: + all_cross_attns += (layer_cross_attn,) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns + else: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFPegasusMainLayer(keras.layers.Layer): + config_class = PegasusConfig + + def __init__(self, config: PegasusConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.shared = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.d_model, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), + name="model.shared", + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "model.shared" + + self.encoder = TFPegasusEncoder(config, self.shared, name="encoder") + self.decoder = TFPegasusDecoder(config, self.shared, name="decoder") + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + @unpack_inputs + def call( + self, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + decoder_input_ids: tf.Tensor | None = None, + decoder_attention_mask: tf.Tensor | None = None, + decoder_position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + decoder_head_mask: tf.Tensor | None = None, + cross_attn_head_mask: tf.Tensor | None = None, + encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, + past_key_values: Tuple[Tuple[tf.Tensor]] = None, + inputs_embeds: tf.Tensor | 
None = None, + decoder_inputs_embeds: tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + **kwargs, + ): + if decoder_input_ids is None and decoder_inputs_embeds is None: + use_cache = False + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.", + PEGASUS_START_DOCSTRING, +) +class TFPegasusModel(TFPegasusPreTrainedModel): + def __init__(self, config: PegasusConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFPegasusMainLayer(config, name="model") + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSeq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqModelOutput( + last_hidden_state=output.last_hidden_state, + 
past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
+class BiasLayer(keras.layers.Layer):
+    """
+    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
+    so all weights have to be registered in a layer.
+    """
+
+    def __init__(self, shape, initializer, trainable, name, **kwargs):
+        super().__init__(name=name, **kwargs)
+        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
+        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
+        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
+        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
+
+    def call(self, x):
+        return x + self.bias
+
+
+@add_start_docstrings(
+    "The PEGASUS Model with a language modeling head. Can be used for summarization.",
+    PEGASUS_START_DOCSTRING,
+)
+class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
+    _keys_to_ignore_on_load_unexpected = [
+        r"model.encoder.embed_tokens.weight",
+        r"model.decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.model = TFPegasusMainLayer(config, name="model")
+        self.use_cache = config.use_cache
+        # final_logits_bias is registered as a buffer in pytorch, so it is not trainable here for the sake of
+        # consistency.
+        self.bias_layer = BiasLayer(
+            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
+        )
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_output_embeddings(self):
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)
+
+    def get_bias(self):
+        return {"final_logits_bias": self.bias_layer.bias}
+
+    def set_bias(self, value):
+        # Replaces the existing layers containing bias for correct (de)serialization.
+        vocab_size = value["final_logits_bias"].shape[-1]
+        self.bias_layer = BiasLayer(
+            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
+        )
+        self.bias_layer.bias.assign(value["final_logits_bias"])
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+        encoder_outputs: Optional[TFBaseModelOutput] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool = False,
+    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
+        """
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or `-100` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
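+
+            A minimal sketch of preparing `labels` for fine-tuning (the checkpoint name and texts are illustrative
+            placeholders):
+
+            ```python
+            >>> from transformers import AutoTokenizer
+
+            >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+            >>> batch = tokenizer("A long article ...", text_target="A short summary.", return_tensors="tf")
+            >>> labels = batch["labels"]
+            ```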
+ + Returns: + + """ + + if labels is not None: + labels = tf.where( + labels == self.config.pad_token_id, + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), + labels, + ) + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) + lm_logits = self.bias_layer(lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return TFSeq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, # index 1 of d outputs + decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs + decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs + encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs + encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out + encoder_attentions=outputs.encoder_attentions, # 2 of e out + ) + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + if decoder_attention_mask is 
not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past_key_values is not None: # no xla + past_key_values + decoder_position_ids = past_key_values[0][0].shape[2] + else: # no xla + no past_key_values + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "bias_layer", None) is not None: + with tf.name_scope(self.bias_layer.name): + self.bias_layer.build(None) + + +__all__ = ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/pegasus/tokenization_pegasus.py b/docs/transformers/src/transformers/models/pegasus/tokenization_pegasus.py new file mode 100644 index 0000000000000000000000000000000000000000..c338e0fac1f3ddef4c92a81f4cd7f3d8aa4b7dfb --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus/tokenization_pegasus.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2020 Google and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +logger = logging.get_logger(__name__) + + +# TODO ArthurZ refactor this to only use the added_tokens_encoder + + +@requires(backends=("sentencepiece",)) +class PegasusTokenizer(PreTrainedTokenizer): + r""" + Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. 
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        mask_token (`str`, *optional*, defaults to `"<mask_2>"`):
+            The token used for masking single token values. This is the token used when training this model with
+            masked language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during
+            pretraining. It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+            Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+        mask_token_sent (`str`, *optional*, defaults to `"<mask_1>"`):
+            The token used for masking whole target sentences. This is the token used when training this model with
+            gap sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+            pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+            Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer. If no additional_special_tokens are provided, `<mask_2>`
+            and `<unk_2, ..., unk_102>` are used as additional special tokens corresponding to the [original PEGASUS
+            tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+            that uses the tokens 2 - 104 only for pretraining.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
+            things, to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
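+
+    Example (a usage sketch; the checkpoint name is illustrative and the exact ids depend on its vocabulary):
+
+    ```python
+    >>> from transformers import PegasusTokenizer
+
+    >>> tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+    >>> ids = tokenizer("Studies have shown that owning a dog is good for you").input_ids
+    >>> text = tokenizer.decode(ids, skip_special_tokens=True)
+    ```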
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + pad_token="", + eos_token="", + unk_token="", + mask_token="", + mask_token_sent="", + additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + self.offset = offset + if additional_special_tokens is not None: + if not isinstance(additional_special_tokens, list): + raise TypeError( + f"additional_special_tokens should be of type {type(list)}, but is" + f" {type(additional_special_tokens)}" + ) + additional_special_tokens_extended = ( + ([mask_token_sent] + additional_special_tokens) + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None + else additional_special_tokens + ) + # fill additional tokens with ..., in case not all additional tokens are already taken + additional_special_tokens_extended += [ + f"" for i in range(len(additional_special_tokens_extended), self.offset - 1) + ] + + if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended): + raise ValueError( + "Please make sure that the provided additional_special_tokens do not contain an incorrectly" + f" shifted list of tokens. Found {additional_special_tokens_extended}." + ) + additional_special_tokens = additional_special_tokens_extended + else: + additional_special_tokens_extended = [] + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] + additional_special_tokens += [f"" for i in range(2, self.offset)] + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.mask_token_sent = mask_token_sent + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + _added_tokens_decoder = { + 0: AddedToken(str(pad_token), special=True), + 1: AddedToken(str(eos_token), special=True), + } + + if self.mask_token_sent is not None: + _added_tokens_decoder[2] = AddedToken(mask_token_sent, special=True) + _added_tokens_decoder[3] = AddedToken(str(mask_token), special=True) + + for i in range(2, self.offset): + _added_tokens_decoder[len(_added_tokens_decoder)] = AddedToken(f"", special=True) + + # Force update as we want to make sure vocab is enforced (same as fast) + self._added_tokens_decoder = kwargs.pop("added_tokens_decoder", {}) + self._added_tokens_decoder.update(_added_tokens_decoder) + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + mask_token=mask_token, + pad_token=pad_token, + mask_token_sent=mask_token_sent, + offset=offset, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.sp_model) + self.offset + + def get_vocab(self) -> Dict[str, int]: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) 
+        """Take as input a string and return a list of strings (tokens) for words/sub-words."""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Converts a token (str) to an id using the vocab."""
+        sp_id = self.sp_model.piece_to_id(token)
+        # sentencepiece ids are shifted up by `offset` to make room for the special tokens
+        return sp_id + self.offset
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (integer) to a token (str) using the vocab."""
+        if index < self.offset:
+            return self.sp_model.IdToPiece(index)
+        token = self.sp_model.IdToPiece(index - self.offset)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def num_special_tokens_to_add(self, pair=False):
+        """Just EOS"""
+        return 1
+
+    def _special_token_mask(self, seq):
+        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
+        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
+
+        return [1 if x in all_special_ids else 0 for x in seq]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """Get a list where entries are [1] if a token is [eos] or [pad], else 0."""
+        if already_has_special_tokens:
+            return self._special_token_mask(token_ids_0)
+        elif token_ids_1 is None:
+            return self._special_token_mask(token_ids_0) + [1]
+        else:
+            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A B </s>` (not intended use)
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
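+
+        Example (a sketch; `eos_token_id` is 1 for the published PEGASUS checkpoints):
+
+        ```python
+        >>> from transformers import PegasusTokenizer
+
+        >>> tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+        >>> tokenizer.build_inputs_with_special_tokens([5, 6, 7])
+        [5, 6, 7, 1]
+        ```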
+ """ + if token_ids_1 is None: + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["PegasusTokenizer"] diff --git a/docs/transformers/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/docs/transformers/src/transformers/models/pegasus/tokenization_pegasus_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..af62976cb751450699c389fad874811517124e0e --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus/tokenization_pegasus_fast.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2020 Google and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model PEGASUS.""" + +import os +from shutil import copyfile +from typing import List, Optional, Tuple + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_pegasus import PegasusTokenizer +else: + PegasusTokenizer = None + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +class PegasusTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. 
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = PegasusTokenizer
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        pad_token="<pad>",
+        eos_token="</s>",
+        unk_token="<unk>",
+        mask_token="<mask_2>",
+        mask_token_sent="<mask_1>",
+        additional_special_tokens=None,
+        offset=103,  # entries 2 - 104 are only used for pretraining
+        **kwargs,
+    ):
+        self.offset = offset
+
+        if additional_special_tokens is not None:
+            if not isinstance(additional_special_tokens, list):
+                raise TypeError(
+                    f"additional_special_tokens should be of type {type(list)}, but is"
+                    f" {type(additional_special_tokens)}"
+                )
+
+            additional_special_tokens_extended = (
+                ([mask_token_sent] + additional_special_tokens)
+                if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+                else additional_special_tokens
+            )
+            # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
+            additional_special_tokens_extended += [
+                f"<unk_{i}>" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+            ]
+
+            if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+                raise ValueError(
+                    "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+                    f" shifted list of tokens. Found {additional_special_tokens_extended}."
+                )
+            additional_special_tokens = additional_special_tokens_extended
+        else:
+            additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+            additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]
+
+        # pegasus was designed to support changing the index of the first tokens. If one of the padding/eos/unk/mask
+        # tokens is different from the default, we must rebuild the vocab
+        from_slow = kwargs.pop("from_slow", None)
+        from_slow = from_slow or str(pad_token) != "<pad>" or str(eos_token) != "</s>" or str(unk_token) != "<unk>"
+
+        kwargs.pop("added_tokens_decoder", {})
+
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            pad_token=pad_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            mask_token=mask_token,
+            mask_token_sent=mask_token_sent,
+            offset=offset,
+            additional_special_tokens=additional_special_tokens,
+            from_slow=from_slow,
+            **kwargs,
+        )
+        self.vocab_file = vocab_file
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    def _special_token_mask(self, seq):
+        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
+        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
+
+        if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
+            raise ValueError(
+                "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
+                f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
+            )
+
+        return [1 if x in all_special_ids else 0 for x in seq]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """Get a list where entries are [1] if a token is [eos] or [pad], else 0."""
+        if already_has_special_tokens:
+            return self._special_token_mask(token_ids_0)
+        elif token_ids_1 is None:
+            return self._special_token_mask(token_ids_0) + [1]
+        else:
+            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """
+        Build model inputs from a sequence by adding eos to the end. No bos token is added to the front.
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A B </s>` (not intended use)
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return token_ids_0 + [self.eos_token_id]
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["PegasusTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/pegasus_x/__init__.py b/docs/transformers/src/transformers/models/pegasus_x/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d362bacf491ea32dcbc89e7b8bba4d9ae1a9261 --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus_x/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pegasus_x import * + from .modeling_pegasus_x import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/pegasus_x/configuration_pegasus_x.py b/docs/transformers/src/transformers/models/pegasus_x/configuration_pegasus_x.py new file mode 100644 index 0000000000000000000000000000000000000000..c92f5662b5992f0edbd48c9d37c9151fdf13bc03 --- /dev/null +++ b/docs/transformers/src/transformers/models/pegasus_x/configuration_pegasus_x.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2022, Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PEGASUS-X model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PegasusXConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PegasusXModel`]. It is used to instantiate a + PEGASUS-X model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the PEGASUS-X + [google/pegasus-x-large](https://huggingface.co/google/pegasus-x-large) architecture. 
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 96103):
+            Vocabulary size of the PEGASUS-X model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`PegasusXModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimension of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 16):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 16):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        max_position_embeddings (`int`, *optional*, defaults to 16384):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 1):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+        num_global_tokens (`int`, *optional*, defaults to 32):
+            Number of global tokens to use for the encoder.
+        block_size (`int`, *optional*, defaults to 512):
+            Block size for encoder local attention. Sequence length should be an exact multiple of block size.
+            block_size must be a multiple of 2 if stagger_local_blocks is True.
+        stagger_local_blocks (`bool`, *optional*, defaults to `True`):
+            Whether to stagger every other local attention by half a block.
+
+    Example:
+
+    ```python
+    >>> from transformers import PegasusXConfig, PegasusXModel
+
+    >>> # Initializing a PEGASUS google/pegasus-x-large style configuration
+    >>> configuration = PegasusXConfig()
+
+    >>> # Initializing a model (with random weights) from the google/pegasus-x-large style configuration
+    >>> model = PegasusXModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "pegasus_x"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=96103,
+        max_position_embeddings=16384,
+        encoder_layers=16,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=16,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        decoder_start_token_id=0,
+        scale_embedding=True,
+        pad_token_id=0,
+        eos_token_id=1,
+        forced_eos_token_id=1,
+        num_global_tokens=32,
+        block_size=512,
+        stagger_local_blocks=True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
+        self.num_global_tokens = num_global_tokens
+        self.block_size = block_size
+        self.stagger_local_blocks = stagger_local_blocks
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            decoder_start_token_id=decoder_start_token_id,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
+
+    @property
+    def num_attention_heads(self) -> int:
+        return self.encoder_attention_heads
+
+    @property
+    def hidden_size(self) -> int:
+        return self.d_model
+
+
+__all__ = ["PegasusXConfig"]
diff --git a/docs/transformers/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/docs/transformers/src/transformers/models/pegasus_x/modeling_pegasus_x.py
new file mode 100644
index 0000000000000000000000000000000000000000..646ab195947b9084bfb76b9879198fd28b2557da
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -0,0 +1,1621 @@
+# coding=utf-8
+# Copyright 2022, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PEGASUS-X model.""" + +import dataclasses +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_pegasus_x import PegasusXConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/pegasus-x-base" +_CONFIG_FOR_DOC = "PegasusXConfig" + + +@dataclasses.dataclass +class DimensionInfo: + """Wrapper for dimension info.""" + + batch_size: int # batch size + seq_len: int # token length + block_size: int # block size + num_heads: int # num heads + hidden_dim: int # hidden dim + dim_per_head: int # dim per head + num_blocks: int # num blocks + global_len: int # global length + padded_seq_len: int # padded token seq length + + # Note: Compared to the original Flax implementation, we will pad the token representations to + # a multiple of block size at the start of the encoder layers, so T=P always. + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PegasusX +class PegasusXScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
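+
+    Example (an illustrative doctest with toy sizes; the exact numbers are arbitrary):
+
+    ```python
+    >>> import torch
+    >>> emb = PegasusXScaledWordEmbedding(num_embeddings=10, embedding_dim=4, padding_idx=0, embed_scale=2.0)
+    >>> ids = torch.tensor([[1, 2]])
+    >>> torch.equal(emb(ids), emb.weight[ids] * 2.0)
+    True
+    ```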
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
+
+class PegasusXSinusoidalPositionalEmbedding(nn.Module):
+    """This module produces sinusoidal positional embeddings of any length."""
+
+    def __init__(self, embed_dim, max_scale: int = 10000.0):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.max_scale = max_scale
+
+    @torch.no_grad()
+    def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor:
+        """`input_embeds` is expected to be [bsz x seqlen x emb_dim]."""
+        batch_size, seq_len = input_embeds.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device
+        )[:, None]
+        pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device, dtype=input_embeds.dtype)
+        half_d_feature = self.embed_dim // 2
+        div_term = torch.exp(
+            torch.arange(half_d_feature, device=input_embeds.device, dtype=torch.int64).type_as(input_embeds)
+            * -(np.log(float(self.max_scale)) / (half_d_feature - 1))
+        )
+        pe[:, :half_d_feature] = torch.sin(positions * div_term)
+        pe[:, half_d_feature:] = torch.cos(positions * div_term)
+        return pe[None].expand(batch_size, -1, -1)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX
+class PegasusXAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[PegasusXConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PegasusXGlobalLocalAttention(nn.Module): + """Global + Local attention. 
For use with Encoder only."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        block_size: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.block_size = block_size
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        token_hidden_states: torch.Tensor,
+        global_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Input shape: Batch x Time x Channel"""
+        dim = DimensionInfo(
+            batch_size=token_hidden_states.shape[0],
+            seq_len=token_hidden_states.shape[1],
+            block_size=self.block_size,
+            num_heads=self.num_heads,
+            hidden_dim=token_hidden_states.shape[2],
+            dim_per_head=self.head_dim,
+            num_blocks=token_hidden_states.shape[1] // self.block_size,
+            global_len=global_hidden_states.shape[1],
+            padded_seq_len=token_hidden_states.shape[1],
+        )
+
+        # [batch_size, num_heads, padded_seq_len, dim_per_head]
+        local_q = self._shape(
+            self.q_proj(token_hidden_states) * self.scaling,
+            seq_len=dim.padded_seq_len,
+            bsz=dim.batch_size,
+        )
+        local_k = self._shape(
+            self.k_proj(token_hidden_states),
+            seq_len=dim.padded_seq_len,
+            bsz=dim.batch_size,
+        )
+        local_v = self._shape(
+            self.v_proj(token_hidden_states),
+            seq_len=dim.padded_seq_len,
+            bsz=dim.batch_size,
+        )
+
+        # [batch_size, num_heads, global_len, dim_per_head]
+        global_q = self._shape(
+            self.q_proj(global_hidden_states) * self.scaling,
+            seq_len=dim.global_len,
+            bsz=dim.batch_size,
+        )
+        global_k = self._shape(
+            self.k_proj(global_hidden_states),
+            seq_len=dim.global_len,
+            bsz=dim.batch_size,
+        )
+        global_v = self._shape(
+            self.v_proj(global_hidden_states),
+            seq_len=dim.global_len,
+            bsz=dim.batch_size,
+        )
+
+        global_attn_output, global_attn_probs = self.compute_global_attention_representations(
+            global_q=global_q,
+            global_k=global_k,
+            global_v=global_v,
+            local_k=local_k,
+            local_v=local_v,
+            mask=attention_mask,
+            dim=dim,
+        )
+        local_attn_output, local_attn_probs = self.compute_local_attention_representations(
+            global_k=global_k,
+            global_v=global_v,
+            local_q=local_q,
+            local_k=local_k,
+            local_v=local_v,
+            mask=attention_mask,
+            dim=dim,
+        )
+
+        # [batch_size, global_len, hidden_dim]
+        global_attn_output = (
+            global_attn_output.transpose(1, 2).contiguous().view(dim.batch_size, dim.global_len, dim.hidden_dim)
+        )
+        # [batch_size, global_len, hidden_dim]
+        global_attn_output = self.out_proj(global_attn_output)
+        # [batch_size, num_blocks, block_size, num_heads, dim_per_head]
+        local_attn_output = local_attn_output.permute(0, 2, 3, 1, 4).contiguous()
+        # [batch_size, padded_seq_len, hidden_dim]
+        local_attn_output = local_attn_output.view(dim.batch_size, dim.padded_seq_len, dim.hidden_dim)
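+        # Shape walk-through of the two steps above, with illustrative numbers: for batch_size=2,
+        # num_heads=4, num_blocks=3, block_size=8, dim_per_head=16, the permute turns
+        # [2, 4, 3, 8, 16] into [2, 3, 8, 4, 16], and the view then merges num_blocks * block_size
+        # back into padded_seq_len=24 and num_heads * dim_per_head into hidden_dim=64.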
+ # [batch_size, padded_seq_len, hidden_dim] + local_attn_output = self.out_proj(local_attn_output) + + if output_attentions: + attn_probs = {"global": global_attn_probs, "local": local_attn_probs} + else: + attn_probs = None + + return local_attn_output, global_attn_output, attn_probs + + def compute_global_attention_representations( + self, global_q, global_k, global_v, local_k, local_v, mask, dim: DimensionInfo + ): + """Compute attention representations for global tokens. + + Global tokens will attend to both global tokens as well as all input sequence tokens. Because the input + sequence tokens are arranged in blocks for local attention, we unblock them and compute attention. + + Args: + global_q (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]: + query vectors from global tokens + global_k (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]: + key vectors from global tokens + global_v (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]: + value vectors from global tokens + local_k (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]: + key vectors from local tokens + local_v (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]: + value vectors from local tokens + mask (`torch.FloatTensor`) of shape [batch_size, padded_seq_len]: attention mask + dim (DimensionInfo): DimensionInfo wrapper for dimensions + + Returns: + output of shape `[batch_sizes, length, features]`. where length will be padded to a multiple of block_size + """ + # [batch_size, num_heads, global_len+padded_seq_len, dim_per_head] + global_and_local_k = torch.cat([global_k, local_k], dim=2) + # [batch_size, num_heads, global_len+padded_seq_len, dim_per_head] + global_and_local_v = torch.cat([global_v, local_v], dim=2) + + # [batch_size, global_len+padded_seq_len] + extended_mask = nn.functional.pad(mask, pad=(dim.global_len, 0), value=0) + + # [batch_size, num_heads, global_len, global_len+padded_seq_len] + attn_weights = torch.einsum("BHGF,BHXF->BHGX", global_q, global_and_local_k) + attn_weights = attn_weights + extended_mask[:, None, None, :] + attn_probs = nn.functional.softmax(attn_weights, dim=-1) + attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) + + # [batch_size, num_heads, global_len, F] + attn_output = torch.einsum("BHGX,BHXF->BHGF", attn_probs, global_and_local_v) + return attn_output, attn_probs + + def compute_local_attention_representations( + self, global_k, global_v, local_q, local_k, local_v, mask, dim: DimensionInfo + ): + """Compute attention representations for local tokens. + + Local tokens will attend to both global tokens as well as all other tokens within the same local block. 
Hence, + we need to tile and concatenate the global tokens to every local block + + Args: + global_k (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]: + key vectors from global tokens + global_v (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]: + value vectors from global tokens + local_q (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]: + query vectors from local tokens + local_k (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]: + key vectors from local tokens + local_v (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]: + value vectors from local tokens + mask (`torch.FloatTensor`) of shape [batch_size, padded_seq_len]: attention mask + dim (DimensionInfo): DimensionInfo wrapper for dimensions + + Returns: + output of shape `[batch_sizes, length, features]`. where length will be padded to a multiple of block_size + """ + # [batch_size, num_heads, num_blocks, block_size, dim_per_head] + blocked_local_q = local_q.view(dim.batch_size, dim.num_heads, dim.num_blocks, dim.block_size, dim.dim_per_head) + # [batch_size, num_heads, num_blocks, block_size, dim_per_head] + blocked_local_k = local_k.view(dim.batch_size, dim.num_heads, dim.num_blocks, dim.block_size, dim.dim_per_head) + # [batch_size, num_heads, num_blocks, block_size, dim_per_head] + blocked_local_v = local_v.view(dim.batch_size, dim.num_heads, dim.num_blocks, dim.block_size, dim.dim_per_head) + + # [batch_size, num_blocks, global_len+block_size] + extended_mask = nn.functional.pad( + mask.view(dim.batch_size, dim.num_blocks, dim.block_size), + pad=(dim.global_len, 0), + value=0, + ) + + # [batch_size, num_heads, num_blocks, block_size, global_len] + blocked_local2global = torch.einsum("BHNKF,BHGF->BHNKG", blocked_local_q, global_k) + # [batch_size, num_heads, num_blocks, block_size, block_size] + blocked_local2local = torch.einsum("BHNKF,BHNXF->BHNKX", blocked_local_q, blocked_local_k) + + # [batch_size, num_heads, num_blocks, block_size, global_len+block_size] + attn_weights = torch.cat([blocked_local2global, blocked_local2local], dim=-1) + attn_weights = attn_weights + extended_mask[:, None, :, None, :] + attn_probs = nn.functional.softmax(attn_weights, dim=-1) + attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) + + # [batch_size, num_heads, num_blocks, block_size, global_len] + local2global_attn_probs = attn_probs[:, :, :, :, : dim.global_len] + # [batch_size, num_heads, num_blocks, block_size, block_size] + local2local_attn_probs = attn_probs[:, :, :, :, dim.global_len :] + + # [batch_size, num_heads, num_blocks, block_size, dim_per_head] + local2global_attn_output = torch.einsum("BHNKG,BHGF->BHNKF", local2global_attn_probs, global_v) + # [batch_size, num_heads, num_blocks, block_size, dim_per_head] + local2local_attn_output = torch.einsum("BHNKX,BHNXF->BHNKF", local2local_attn_probs, blocked_local_v) + # [batch_size, num_heads, num_blocks, block_size, dim_per_head] + attn_output = local2global_attn_output + local2local_attn_output + return attn_output, attn_probs + + +class PegasusXEncoderLayer(nn.Module): + def __init__(self, stagger_blocks_this_layer: bool, config: PegasusXConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = PegasusXGlobalLocalAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + block_size=config.block_size, + dropout=config.attention_dropout, 
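+            # Note: this module applies the same q/k/v/out projections to both the local token
+            # stream and the global token stream (see PegasusXGlobalLocalAttention.forward).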
+ ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.global_self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.stagger_blocks_this_layer = stagger_blocks_this_layer + self.block_size = config.block_size + + def forward( + self, + hidden_states: torch.Tensor, + global_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* + global_hidden_states (`torch.FloatTensor`): global token hidden states + *(seq_len, num_global_tokens, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + global_residual = global_hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + global_hidden_states = self.global_self_attn_layer_norm(global_hidden_states) + + if self.stagger_blocks_this_layer: + # Pad the blocks to simulate staggering + hidden_states, attention_mask = self.pad_local_tokens( + hidden_states=hidden_states, attention_mask=attention_mask, block_size=self.block_size + ) + + hidden_states, global_hidden_states, attn_weights = self.self_attn( + token_hidden_states=hidden_states, + global_hidden_states=global_hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + if self.stagger_blocks_this_layer: + # Undo the padding + hidden_states = self.unpad_local_tokens(padded_hidden_states=hidden_states, block_size=self.block_size) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.dropout, training=self.training) + global_hidden_states = global_residual + global_hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + global_residual = global_hidden_states + global_hidden_states = self.final_layer_norm(global_hidden_states) + global_hidden_states = self.activation_fn(self.fc1(global_hidden_states)) + global_hidden_states = nn.functional.dropout( + global_hidden_states, p=self.activation_dropout, training=self.training + ) + global_hidden_states = self.fc2(global_hidden_states) + global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.dropout, training=self.training) + global_hidden_states = global_residual + global_hidden_states + outputs = (hidden_states, global_hidden_states) + + if output_attentions: + outputs += (attn_weights,) + + return 
outputs + + @classmethod + def pad_local_tokens(cls, hidden_states, attention_mask, block_size): + # hidden_states: [batch_size, seq_len, hidden_dim] + pad_size = block_size // 2 + mask_min_value = torch.finfo(hidden_states.dtype).min + padded_hidden_states = torch.nn.functional.pad( + hidden_states, + pad=(0, 0, pad_size, pad_size), + ) + padded_mask = torch.nn.functional.pad( + attention_mask, + pad=(pad_size, pad_size), + value=mask_min_value, + ) + return padded_hidden_states, padded_mask + + @classmethod + def unpad_local_tokens(cls, padded_hidden_states, block_size): + # padded_hidden_states: [batch_size, padded seq_len, hidden_dim] + pad_size = block_size // 2 + return padded_hidden_states[:, pad_size:-pad_size, :] + + +class PegasusXDecoderLayer(nn.Module): + def __init__(self, config: PegasusXConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PegasusXAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=False, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PegasusXAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=False, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape *(seq_len, batch, embed_dim)* + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+            use_cache: Whether to use the KV cache for decoding
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class PegasusXPreTrainedModel(PreTrainedModel):
+    config_class = PegasusXConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+
+
+PEGASUS_X_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`PegasusXConfig`]):
+            Model configuration class with all the parameters of the model. 
Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_X_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, PegasusXForConditionalGeneration
+
+    >>> model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-base")
+
+    >>> ARTICLE_TO_SUMMARIZE = (
+    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+    ... )
+    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"])
+    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "California's largest electricity provider has turned off power to hundreds of thousands of customers."
+    ```
+"""
+
+PEGASUS_X_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            PEGASUS-X uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class PegasusXEncoder(PegasusXPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PegasusXEncoderLayer`]. 
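+
+    In addition to the token-level states, the encoder maintains *config.num_global_tokens* learned global token
+    embeddings per sequence, and pads the token sequence to a multiple of *config.block_size* for block-local
+    attention.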
+ + Args: + config: PegasusXConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = PegasusXScaledWordEmbedding( + config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale + ) + + self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim) + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim) + self.layers = nn.ModuleList( + [ + PegasusXEncoderLayer( + stagger_blocks_this_layer=i % 2 == 1 and config.stagger_local_blocks, config=config + ) + for i in range(config.encoder_layers) + ] + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model) + self.embed_positions.to(self.device) + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.embed_positions + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(inputs_embeds) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + batch_size, seq_len, _ = hidden_states.shape + + # Setup mask + if attention_mask is None: + attention_mask = torch.ones(*input_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + attention_mask = attention_mask.to(dtype=hidden_states.dtype) + mask_min_value = torch.finfo(hidden_states.dtype).min + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + mask_min_value, + ) + + # padding to block_size + if seq_len % self.config.block_size != 0: + pad_len = self.config.block_size - seq_len % self.config.block_size + hidden_states = nn.functional.pad(hidden_states, pad=(0, 0, 0, pad_len), value=0) + attention_mask = nn.functional.pad(attention_mask, pad=(0, pad_len), value=mask_min_value) + + # Global tokens + global_hidden_states = self.embed_global( + torch.arange(self.config.num_global_tokens, device=hidden_states.device)[None].expand(batch_size, -1) + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + global_hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + global_hidden_states, + attention_mask, + 
output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + global_hidden_states = layer_outputs[1] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + + # Undo padding-to-block-size + hidden_states = hidden_states[:, :seq_len] + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + ((hidden_states, global_hidden_states),) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class PegasusXDecoder(PegasusXPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`] + + Args: + config: PegasusXConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + padding_idx = config.pad_token_id + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = PegasusXScaledWordEmbedding( + config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model) + self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
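+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
+                decoding (see `past_key_values`).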
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(inputs_embeds, past_key_values_length) + + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare PEGASUS-X Model outputting raw hidden-states without any specific head on top.", + PEGASUS_X_START_DOCSTRING, +) +class PegasusXModel(PegasusXPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PegasusXConfig): + super().__init__(config) + + vocab_size = config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + padding_idx = config.pad_token_id + self.shared = PegasusXScaledWordEmbedding( + vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale + ) + + self.encoder = PegasusXEncoder(config, self.shared) + self.decoder = PegasusXDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. 
+
+        Arguments:
+            new_num_position_embeddings (`int`):
+                The number of new position embeddings. If position embeddings are learned, increasing the size will add
+                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+                will remove vectors from the end.
+        """
+        self.config.max_position_embeddings = new_num_position_embeddings
+        self.encoder.resize_position_embeddings(new_num_position_embeddings)
+        self.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
+        """
+        Returns the position embeddings matrix
+        """
+        return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
+
+    @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PegasusXModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-large")
+        >>> model = PegasusXModel.from_pretrained("google/pegasus-x-large")
+
+        >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
+        >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
+        >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
+
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 4, 1024]
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        
decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("The PEGASUS-X for conditional generation (e.g. summarization).", PEGASUS_X_START_DOCSTRING) +class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PegasusXConfig): + super().__init__(config) + self.model = PegasusXModel(config) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
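+
+        Example (a minimal sketch; the new length is illustrative):
+
+        ```python
+        >>> from transformers import PegasusXForConditionalGeneration
+
+        >>> model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-large")
+        >>> model.resize_position_embeddings(8192)  # grow the encoder and decoder position tables
+        ```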
+ """ + self.config.max_position_embeddings = new_num_position_embeddings + self.model.encoder.resize_position_embeddings(new_num_position_embeddings) + self.model.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> Tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings()) + + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(PEGASUS_X_GENERATION_EXAMPLE) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, 
self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PegasusX +class PegasusXDecoderWrapper(PegasusXPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PegasusXDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +__all__ = ["PegasusXForConditionalGeneration", "PegasusXModel", "PegasusXPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/perceiver/__init__.py b/docs/transformers/src/transformers/models/perceiver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8d30a5a9b54b27afb74fba377024c2c95436f6 --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_perceiver import * + from .feature_extraction_perceiver import * + from .image_processing_perceiver import * + from .image_processing_perceiver_fast import * + from .modeling_perceiver import * + from .tokenization_perceiver import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/perceiver/configuration_perceiver.py b/docs/transformers/src/transformers/models/perceiver/configuration_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9da1d650421c0f176e85a5c4dbcab5bdadbdc8 --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/configuration_perceiver.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright Deepmind and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Perceiver model configuration"""
+
+from collections import OrderedDict
+from typing import Any, Mapping, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...feature_extraction_utils import FeatureExtractionMixin
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
+from ...tokenization_utils_base import PreTrainedTokenizerBase
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PerceiverConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PerceiverModel`]. It is used to instantiate a
+    Perceiver model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Perceiver
+    [deepmind/language-perceiver](https://huggingface.co/deepmind/language-perceiver) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_latents (`int`, *optional*, defaults to 256):
+            The number of latents.
+        d_latents (`int`, *optional*, defaults to 1280):
+            Dimension of the latent embeddings.
+        d_model (`int`, *optional*, defaults to 768):
+            Dimension of the inputs. Should only be provided in case [*PerceiverTextPreprocessor*] is used or no
+            preprocessor is provided.
+        num_blocks (`int`, *optional*, defaults to 1):
+            Number of blocks in the Transformer encoder.
+        num_self_attends_per_block (`int`, *optional*, defaults to 26):
+            The number of self-attention layers per block.
+        num_self_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each self-attention layer in the Transformer encoder.
+        num_cross_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each cross-attention layer in the Transformer encoder.
+        qk_channels (`int`, *optional*):
+            Dimension to project the queries + keys before applying attention in the cross-attention and self-attention
+            layers of the encoder. Will default to preserving the dimension of the queries if not specified.
+        v_channels (`int`, *optional*):
+            Dimension to project the values before applying attention in the cross-attention and self-attention layers
+            of the encoder. Will default to preserving the dimension of the queries if not specified.
+        cross_attention_shape_for_attention (`str`, *optional*, defaults to `"kv"`):
+            Dimension to use when downsampling the queries and keys in the cross-attention layer of the encoder.
+        self_attention_widening_factor (`int`, *optional*, defaults to 1):
+            Widening factor of the feed-forward layer in the self-attention layers of the Transformer encoder.
+        cross_attention_widening_factor (`int`, *optional*, defaults to 1):
+            Widening factor of the feed-forward layer in the cross-attention layer of the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        use_query_residual (`bool`, *optional*, defaults to `True`):
+            Whether to add a query residual in the cross-attention layer of the encoder.
+        vocab_size (`int`, *optional*, defaults to 262):
+            Vocabulary size for the masked language modeling model.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that the masked language modeling model might ever be used with. Typically set
+            this to something large just in case (e.g., 512 or 1024 or 2048).
+        image_size (`int`, *optional*, defaults to 56):
+            Size of the images after preprocessing, for [`PerceiverForImageClassificationLearned`].
+        train_size (`List[int]`, *optional*, defaults to `[368, 496]`):
+            Training size of the images for the optical flow model.
+        num_frames (`int`, *optional*, defaults to 16):
+            Number of video frames used for the multimodal autoencoding model.
+        audio_samples_per_frame (`int`, *optional*, defaults to 1920):
+            Number of audio samples per frame for the multimodal autoencoding model.
+        samples_per_patch (`int`, *optional*, defaults to 16):
+            Number of audio samples per patch when preprocessing the audio for the multimodal autoencoding model.
+        output_shape (`List[int]`, *optional*, defaults to `[1, 16, 224, 224]`):
+            Shape of the output (batch_size, num_frames, height, width) for the video decoder queries of the multimodal
+            autoencoding model. This excludes the channel dimension.
+        output_num_channels (`int`, *optional*, defaults to 512):
+            Number of output channels for each modality decoder.
+ + Example: + + ```python + >>> from transformers import PerceiverModel, PerceiverConfig + + >>> # Initializing a Perceiver deepmind/language-perceiver style configuration + >>> configuration = PerceiverConfig() + + >>> # Initializing a model from the deepmind/language-perceiver style configuration + >>> model = PerceiverModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "perceiver" + + def __init__( + self, + num_latents=256, + d_latents=1280, + d_model=768, + num_blocks=1, + num_self_attends_per_block=26, + num_self_attention_heads=8, + num_cross_attention_heads=8, + qk_channels=None, + v_channels=None, + cross_attention_shape_for_attention="kv", + self_attention_widening_factor=1, + cross_attention_widening_factor=1, + hidden_act="gelu", + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_query_residual=True, + vocab_size=262, + max_position_embeddings=2048, + image_size=56, + train_size=[368, 496], + num_frames=16, + audio_samples_per_frame=1920, + samples_per_patch=16, + output_shape=[1, 16, 224, 224], + output_num_channels=512, + _label_trainable_num_channels=1024, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_latents = num_latents + self.d_latents = d_latents + self.d_model = d_model + self.num_blocks = num_blocks + self.num_self_attends_per_block = num_self_attends_per_block + self.num_self_attention_heads = num_self_attention_heads + self.num_cross_attention_heads = num_cross_attention_heads + self.qk_channels = qk_channels + self.v_channels = v_channels + self.cross_attention_shape_for_attention = cross_attention_shape_for_attention + self.self_attention_widening_factor = self_attention_widening_factor + self.cross_attention_widening_factor = cross_attention_widening_factor + self.hidden_act = hidden_act + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_query_residual = use_query_residual + # masked language modeling attributes + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + # image classification attributes + self.image_size = image_size + # flow attributes + self.train_size = train_size + # multimodal autoencoding attributes + self.num_frames = num_frames + self.audio_samples_per_frame = audio_samples_per_frame + self.samples_per_patch = samples_per_patch + self.output_shape = output_shape + self.output_num_channels = output_num_channels + self._label_trainable_num_channels = _label_trainable_num_channels + + +class PerceiverOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("inputs", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"], + batch_size: int = -1, + seq_length: int = -1, + num_choices: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + num_channels: int = 3, + image_width: int = 40, + image_height: int = 40, + ) -> Mapping[str, Any]: + # copied from `transformers.onnx.config.OnnxConfig` and slightly altered/simplified + + if isinstance(preprocessor, 
PreTrainedTokenizerBase):
+            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+            batch_size = compute_effective_axis_dimension(
+                batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+            )
+            # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+            token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
+            seq_length = compute_effective_axis_dimension(
+                seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+            )
+            # Generate dummy inputs according to the computed batch and sequence length
+            # (seq_length space-separated dummy tokens per sample)
+            dummy_input = [" ".join(["a"] * seq_length)] * batch_size
+            inputs = dict(preprocessor(dummy_input, return_tensors=framework))
+            inputs["inputs"] = inputs.pop("input_ids")
+            return inputs
+        elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
+            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+            dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+            inputs = dict(preprocessor(images=dummy_input, return_tensors=framework))
+            inputs["inputs"] = inputs.pop("pixel_values")
+            return inputs
+        else:
+            raise ValueError(
+                "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
+            )
+
+
+__all__ = ["PerceiverConfig", "PerceiverOnnxConfig"]
diff --git a/docs/transformers/src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py b/docs/transformers/src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..082b9449374a7eb7e07bd40eb9d349f83ef5bcba
--- /dev/null
+++ b/docs/transformers/src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py
@@ -0,0 +1,468 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
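+#
+# Note: running this conversion script additionally requires the original Haiku stack
+# (e.g. `pip install dm-haiku`), since the checkpoints below are unpickled into Haiku
+# FlatMapping structures.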
+"""Convert Perceiver checkpoints originally implemented in Haiku.""" + +import argparse +import json +import pickle +from pathlib import Path + +import haiku as hk +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + PerceiverConfig, + PerceiverForImageClassificationConvProcessing, + PerceiverForImageClassificationFourier, + PerceiverForImageClassificationLearned, + PerceiverForMaskedLM, + PerceiverForMultimodalAutoencoding, + PerceiverForOpticalFlow, + PerceiverImageProcessor, + PerceiverTokenizer, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def prepare_img(): + # We will verify our results on an image of a dog + url = "https://storage.googleapis.com/perceiver_io/dalmation.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +def rename_keys(state_dict, architecture): + for name in list(state_dict): + param = state_dict.pop(name) + + # PREPROCESSORS + # rename text preprocessor embeddings (for MLM model) + name = name.replace("embed/embeddings", "input_preprocessor.embeddings.weight") + if name.startswith("trainable_position_encoding/pos_embs"): + name = name.replace( + "trainable_position_encoding/pos_embs", "input_preprocessor.position_embeddings.weight" + ) + + # rename image preprocessor embeddings (for image classification model with learned position embeddings) + name = name.replace("image_preprocessor/~/conv2_d/w", "input_preprocessor.convnet_1x1.weight") + name = name.replace("image_preprocessor/~/conv2_d/b", "input_preprocessor.convnet_1x1.bias") + name = name.replace( + "image_preprocessor/~_build_network_inputs/trainable_position_encoding/pos_embs", + "input_preprocessor.position_embeddings.position_embeddings", + ) + name = name.replace( + "image_preprocessor/~_build_network_inputs/position_encoding_projector/linear/w", + "input_preprocessor.positions_projection.weight", + ) + name = name.replace( + "image_preprocessor/~_build_network_inputs/position_encoding_projector/linear/b", + "input_preprocessor.positions_projection.bias", + ) + + # rename image preprocessor embeddings (for image classification model with conv processing) + if "counter" in name or "hidden" in name: + continue + name = name.replace( + "image_preprocessor/~/conv2_d_downsample/~/conv/w", "input_preprocessor.convnet.conv.weight" + ) + name = name.replace( + "image_preprocessor/~/conv2_d_downsample/~/batchnorm/offset", "input_preprocessor.convnet.batchnorm.bias" + ) + name = name.replace( + "image_preprocessor/~/conv2_d_downsample/~/batchnorm/scale", "input_preprocessor.convnet.batchnorm.weight" + ) + name = name.replace( + "image_preprocessor/~/conv2_d_downsample/~/batchnorm/~/mean_ema/average", + "input_preprocessor.convnet.batchnorm.running_mean", + ) + name = name.replace( + "image_preprocessor/~/conv2_d_downsample/~/batchnorm/~/var_ema/average", + "input_preprocessor.convnet.batchnorm.running_var", + ) + + # rename image preprocessor embeddings (for optical flow model) + name = name.replace("image_preprocessor/patches_linear/b", "input_preprocessor.conv_after_patches.bias") + name = name.replace("image_preprocessor/patches_linear/w", "input_preprocessor.conv_after_patches.weight") + + # rename multimodal preprocessor embeddings + name = name.replace("multimodal_preprocessor/audio_mask_token/pos_embs", "input_preprocessor.mask.audio") + name = name.replace("multimodal_preprocessor/audio_padding/pos_embs", 
"input_preprocessor.padding.audio") + name = name.replace("multimodal_preprocessor/image_mask_token/pos_embs", "input_preprocessor.mask.image") + name = name.replace("multimodal_preprocessor/image_padding/pos_embs", "input_preprocessor.padding.image") + name = name.replace("multimodal_preprocessor/label_mask_token/pos_embs", "input_preprocessor.mask.label") + name = name.replace("multimodal_preprocessor/label_padding/pos_embs", "input_preprocessor.padding.label") + + # DECODERS + # rename prefix of decoders + # multimodal autoencoding model + name = name.replace( + "multimodal_decoder/~/basic_decoder/cross_attention/", "decoder.decoder.decoding_cross_attention." + ) + name = name.replace("multimodal_decoder/~decoder_query/audio_padding/pos_embs", "decoder.padding.audio") + name = name.replace("multimodal_decoder/~decoder_query/image_padding/pos_embs", "decoder.padding.image") + name = name.replace("multimodal_decoder/~decoder_query/label_padding/pos_embs", "decoder.padding.label") + name = name.replace("multimodal_decoder/~/basic_decoder/output/b", "decoder.decoder.final_layer.bias") + name = name.replace("multimodal_decoder/~/basic_decoder/output/w", "decoder.decoder.final_layer.weight") + if architecture == "multimodal_autoencoding": + name = name.replace( + "classification_decoder/~/basic_decoder/~/trainable_position_encoding/pos_embs", + "decoder.modalities.label.decoder.output_position_encodings.position_embeddings", + ) + # flow model + name = name.replace( + "flow_decoder/~/basic_decoder/cross_attention/", "decoder.decoder.decoding_cross_attention." + ) + name = name.replace("flow_decoder/~/basic_decoder/output/w", "decoder.decoder.final_layer.weight") + name = name.replace("flow_decoder/~/basic_decoder/output/b", "decoder.decoder.final_layer.bias") + # image models + name = name.replace( + "classification_decoder/~/basic_decoder/~/trainable_position_encoding/pos_embs", + "decoder.decoder.output_position_encodings.position_embeddings", + ) + name = name.replace( + "basic_decoder/~/trainable_position_encoding/pos_embs", + "decoder.output_position_encodings.position_embeddings", + ) + name = name.replace( + "classification_decoder/~/basic_decoder/cross_attention/", "decoder.decoder.decoding_cross_attention." 
+        )
+        name = name.replace("classification_decoder/~/basic_decoder/output/b", "decoder.decoder.final_layer.bias")
+        name = name.replace("classification_decoder/~/basic_decoder/output/w", "decoder.decoder.final_layer.weight")
+        name = name.replace("classification_decoder/~/basic_decoder/~/", "decoder.decoder.")
+        name = name.replace("basic_decoder/cross_attention/", "decoder.decoding_cross_attention.")
+        name = name.replace("basic_decoder/~/", "decoder.")
+
+        # POSTPROCESSORS
+        name = name.replace(
+            "projection_postprocessor/linear/b", "output_postprocessor.modalities.image.classifier.bias"
+        )
+        name = name.replace(
+            "projection_postprocessor/linear/w", "output_postprocessor.modalities.image.classifier.weight"
+        )
+        name = name.replace(
+            "classification_postprocessor/linear/b", "output_postprocessor.modalities.label.classifier.bias"
+        )
+        name = name.replace(
+            "classification_postprocessor/linear/w", "output_postprocessor.modalities.label.classifier.weight"
+        )
+        name = name.replace("audio_postprocessor/linear/b", "output_postprocessor.modalities.audio.classifier.bias")
+        name = name.replace("audio_postprocessor/linear/w", "output_postprocessor.modalities.audio.classifier.weight")
+
+        # PERCEIVER MODEL
+
+        # rename latent embeddings
+        name = name.replace("perceiver_encoder/~/trainable_position_encoding/pos_embs", "embeddings.latents")
+        # rename latent embeddings (for multimodal model)
+        name = name.replace("encoder/~/trainable_position_encoding/pos_embs", "embeddings.latents")
+
+        # rename prefixes
+        if name.startswith("perceiver_encoder/~/"):
+            if "self_attention" in name:
+                suffix = "self_attends."
+            else:
+                suffix = ""
+            name = name.replace("perceiver_encoder/~/", "encoder." + suffix)
+        if name.startswith("encoder/~/"):
+            if "self_attention" in name:
+                suffix = "self_attends."
+            else:
+                suffix = ""
+            name = name.replace("encoder/~/", "encoder."
+ suffix) + # rename layernorm parameters + if "offset" in name: + name = name.replace("offset", "bias") + if "scale" in name: + name = name.replace("scale", "weight") + # in HuggingFace, the layernorm in between attention + MLP is just called "layernorm" + # rename layernorm in between attention + MLP of cross-attention + if "cross_attention" in name and "layer_norm_2" in name: + name = name.replace("layer_norm_2", "layernorm") + # rename layernorm in between attention + MLP of self-attention + if "self_attention" in name and "layer_norm_1" in name: + name = name.replace("layer_norm_1", "layernorm") + + # in HuggingFace, the layernorms for queries + keys are called "layernorm1" and "layernorm2" + if "cross_attention" in name and "layer_norm_1" in name: + name = name.replace("layer_norm_1", "attention.self.layernorm2") + if "cross_attention" in name and "layer_norm" in name: + name = name.replace("layer_norm", "attention.self.layernorm1") + if "self_attention" in name and "layer_norm" in name: + name = name.replace("layer_norm", "attention.self.layernorm1") + + # rename special characters by dots + name = name.replace("-", ".") + name = name.replace("/", ".") + # rename keys, queries, values and output of attention layers + if ("cross_attention" in name or "self_attention" in name) and "mlp" not in name: + if "linear.b" in name: + name = name.replace("linear.b", "self.query.bias") + if "linear.w" in name: + name = name.replace("linear.w", "self.query.weight") + if "linear_1.b" in name: + name = name.replace("linear_1.b", "self.key.bias") + if "linear_1.w" in name: + name = name.replace("linear_1.w", "self.key.weight") + if "linear_2.b" in name: + name = name.replace("linear_2.b", "self.value.bias") + if "linear_2.w" in name: + name = name.replace("linear_2.w", "self.value.weight") + if "linear_3.b" in name: + name = name.replace("linear_3.b", "output.dense.bias") + if "linear_3.w" in name: + name = name.replace("linear_3.w", "output.dense.weight") + if "self_attention_" in name: + name = name.replace("self_attention_", "") + if "self_attention" in name: + name = name.replace("self_attention", "0") + # rename dense layers of 2-layer MLP + if "mlp" in name: + if "linear.b" in name: + name = name.replace("linear.b", "dense1.bias") + if "linear.w" in name: + name = name.replace("linear.w", "dense1.weight") + if "linear_1.b" in name: + name = name.replace("linear_1.b", "dense2.bias") + if "linear_1.w" in name: + name = name.replace("linear_1.w", "dense2.weight") + + # finally, TRANSPOSE if kernel and not embedding layer, and set value + if name[-6:] == "weight" and "embeddings" not in name: + param = np.transpose(param) + + # if batchnorm, we need to squeeze it + if "batchnorm" in name: + param = np.squeeze(param) + + if "embedding_decoder" not in name: + state_dict["perceiver." + name] = torch.from_numpy(param) + else: + state_dict[name] = torch.from_numpy(param) + + +@torch.no_grad() +def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architecture="MLM"): + """ + Copy/paste/tweak model's weights to our Perceiver structure. 
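+
+    Example invocation (file names are illustrative, not actual checkpoint names):
+
+    ```
+    python convert_perceiver_haiku_to_pytorch.py --pickle_file ./perceiver_mlm_checkpoint.pickle \
+        --pytorch_dump_folder_path ./perceiver-converted --architecture MLM
+    ```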
+ """ + + # load parameters as FlatMapping data structure + with open(pickle_file, "rb") as f: + checkpoint = pickle.loads(f.read()) + + state = None + if isinstance(checkpoint, dict) and architecture in [ + "image_classification", + "image_classification_fourier", + "image_classification_conv", + ]: + # the image classification_conv checkpoint also has batchnorm states (running_mean and running_var) + params = checkpoint["params"] + state = checkpoint["state"] + else: + params = checkpoint + + # turn into initial state dict + state_dict = {} + for scope_name, parameters in hk.data_structures.to_mutable_dict(params).items(): + for param_name, param in parameters.items(): + state_dict[scope_name + "/" + param_name] = param + + if state is not None: + # add state variables + for scope_name, parameters in hk.data_structures.to_mutable_dict(state).items(): + for param_name, param in parameters.items(): + state_dict[scope_name + "/" + param_name] = param + + # rename keys + rename_keys(state_dict, architecture=architecture) + + # load HuggingFace model + config = PerceiverConfig() + subsampling = None + repo_id = "huggingface/label-files" + if architecture == "MLM": + config.qk_channels = 8 * 32 + config.v_channels = 1280 + model = PerceiverForMaskedLM(config) + elif "image_classification" in architecture: + config.num_latents = 512 + config.d_latents = 1024 + config.d_model = 512 + config.num_blocks = 8 + config.num_self_attends_per_block = 6 + config.num_cross_attention_heads = 1 + config.num_self_attention_heads = 8 + config.qk_channels = None + config.v_channels = None + # set labels + config.num_labels = 1000 + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + if architecture == "image_classification": + config.image_size = 224 + model = PerceiverForImageClassificationLearned(config) + elif architecture == "image_classification_fourier": + config.d_model = 261 + model = PerceiverForImageClassificationFourier(config) + elif architecture == "image_classification_conv": + config.d_model = 322 + model = PerceiverForImageClassificationConvProcessing(config) + else: + raise ValueError(f"Architecture {architecture} not supported") + elif architecture == "optical_flow": + config.num_latents = 2048 + config.d_latents = 512 + config.d_model = 322 + config.num_blocks = 1 + config.num_self_attends_per_block = 24 + config.num_self_attention_heads = 16 + config.num_cross_attention_heads = 1 + model = PerceiverForOpticalFlow(config) + elif architecture == "multimodal_autoencoding": + config.num_latents = 28 * 28 * 1 + config.d_latents = 512 + config.d_model = 704 + config.num_blocks = 1 + config.num_self_attends_per_block = 8 + config.num_self_attention_heads = 8 + config.num_cross_attention_heads = 1 + config.num_labels = 700 + # define dummy inputs + subsampling (as each forward pass is only on a chunk of image + audio data) + images = torch.randn((1, 16, 3, 224, 224)) + audio = torch.randn((1, 30720, 1)) + nchunks = 128 + image_chunk_size = np.prod((16, 224, 224)) // nchunks + audio_chunk_size = audio.shape[1] // config.samples_per_patch // nchunks + # process the first chunk + chunk_idx = 0 + subsampling = { + "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)), + "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 
1)),
+            "label": None,
+        }
+        model = PerceiverForMultimodalAutoencoding(config)
+        # set labels
+        filename = "kinetics700-id2label.json"
+        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+        id2label = {int(k): v for k, v in id2label.items()}
+        config.id2label = id2label
+        config.label2id = {v: k for k, v in id2label.items()}
+    else:
+        raise ValueError(f"Architecture {architecture} not supported")
+    model.eval()
+
+    # load weights
+    model.load_state_dict(state_dict)
+
+    # prepare dummy input
+    input_mask = None
+    if architecture == "MLM":
+        # NOTE: hardcoded local path from the original conversion run; point this at your own tokenizer files
+        tokenizer = PerceiverTokenizer.from_pretrained("/Users/NielsRogge/Documents/Perceiver/Tokenizer files")
+        text = "This is an incomplete sentence where some words are missing."
+        encoding = tokenizer(text, padding="max_length", return_tensors="pt")
+        # mask " missing.". Note that the model performs much better if the masked chunk starts with a space.
+        encoding.input_ids[0, 51:60] = tokenizer.mask_token_id
+        inputs = encoding.input_ids
+        input_mask = encoding.attention_mask
+    elif architecture in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
+        image_processor = PerceiverImageProcessor()
+        image = prepare_img()
+        encoding = image_processor(image, return_tensors="pt")
+        inputs = encoding.pixel_values
+    elif architecture == "optical_flow":
+        inputs = torch.randn(1, 2, 27, 368, 496)
+    elif architecture == "multimodal_autoencoding":
+        images = torch.randn((1, 16, 3, 224, 224))
+        audio = torch.randn((1, 30720, 1))
+        inputs = {"image": images, "audio": audio, "label": torch.zeros((images.shape[0], 700))}
+
+    # forward pass
+    if architecture == "multimodal_autoencoding":
+        outputs = model(inputs=inputs, attention_mask=input_mask, subsampled_output_points=subsampling)
+    else:
+        outputs = model(inputs=inputs, attention_mask=input_mask)
+    logits = outputs.logits
+
+    # verify logits
+    if not isinstance(logits, dict):
+        print("Shape of logits:", logits.shape)
+    else:
+        for k, v in logits.items():
+            print(f"Shape of logits of modality {k}", v.shape)
+
+    if architecture == "MLM":
+        expected_slice = torch.tensor(
+            [[-11.8336, -11.6850, -11.8483], [-12.8149, -12.5863, -12.7904], [-12.8440, -12.6410, -12.8646]]
+        )
+        assert torch.allclose(logits[0, :3, :3], expected_slice)
+        masked_tokens_predictions = logits[0, 51:60].argmax(dim=-1).tolist()
+        expected_list = [38, 115, 111, 121, 121, 111, 116, 109, 52]
+        assert masked_tokens_predictions == expected_list
+        print("Greedy predictions:")
+        print(masked_tokens_predictions)
+        print()
+        print("Predicted string:")
+        print(tokenizer.decode(masked_tokens_predictions))
+
+    elif architecture in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
+        print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
+
+    # Finally, save files
+    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+    print(f"Saving model to {pytorch_dump_folder_path}")
+    model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--pickle_file",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to local pickle file of a Perceiver checkpoint you'd like to convert.\n"
+        "Given the files are in the pickle format, please only pass files from sources you trust.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the output PyTorch model directory, provided as a string.",
string.", + ) + parser.add_argument( + "--architecture", + default="MLM", + type=str, + help=""" + Architecture, provided as a string. One of 'MLM', 'image_classification', image_classification_fourier', + image_classification_fourier', 'optical_flow' or 'multimodal_autoencoding'. + """, + ) + + args = parser.parse_args() + convert_perceiver_checkpoint(args.pickle_file, args.pytorch_dump_folder_path, args.architecture) diff --git a/docs/transformers/src/transformers/models/perceiver/feature_extraction_perceiver.py b/docs/transformers/src/transformers/models/perceiver/feature_extraction_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..566cb31fea78a52c0c7cfaa3e509437c55dab578 --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/feature_extraction_perceiver.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Perceiver.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_perceiver import PerceiverImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class PerceiverFeatureExtractor(PerceiverImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class PerceiverFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use PerceiverImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["PerceiverFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/perceiver/image_processing_perceiver.py b/docs/transformers/src/transformers/models/perceiver/image_processing_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..82d571347392ece5057f62f67061c722edd5046f --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/image_processing_perceiver.py @@ -0,0 +1,353 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for Perceiver.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import center_crop, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class PerceiverImageProcessor(BaseImageProcessor): + r""" + Constructs a Perceiver image processor. + + Args: + do_center_crop (`bool`, `optional`, defaults to `True`): + Whether or not to center crop the image. If the input size if smaller than `crop_size` along any edge, the + image will be padded with zeros and then center cropped. Can be overridden by the `do_center_crop` + parameter in the `preprocess` method. + crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`): + Desired output size when applying center-cropping. Can be overridden by the `crop_size` parameter in the + `preprocess` method. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image to `(size["height"], size["width"])`. Can be overridden by the `do_resize` + parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by the `size` parameter in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter + in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter + in the `preprocess` method. + do_normalize: + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256} + crop_size = get_size_dict(crop_size, param_name="crop_size") + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + def center_crop( + self, + image: np.ndarray, + crop_size: Dict[str, int], + size: Optional[int] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] * + min_dim)`. Where `min_dim = min(size["height"], size["width"])`. + + If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then + center cropped. + + Args: + image (`np.ndarray`): + Image to center crop. + crop_size (`Dict[str, int]`): + Desired output size after applying the center crop. + size (`Dict[str, int]`, *optional*): + Size of the image after resizing. If not provided, the self.size attribute will be used. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = self.size if size is None else size + size = get_size_dict(size) + crop_size = get_size_dict(crop_size, param_name="crop_size") + + height, width = get_image_size(image, channel_dim=input_data_format) + min_dim = min(height, width) + cropped_height = (size["height"] / crop_size["height"]) * min_dim + cropped_width = (size["width"] / crop_size["width"]) * min_dim + return center_crop( + image, + size=(cropped_height, cropped_width), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. 
+            size (`Dict[str, int]`):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image, e.g. `PILImageResampling.BICUBIC`.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+        Returns:
+            `np.ndarray`: The resized image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+        output_size = (size["height"], size["width"])
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Optional[Dict[str, int]] = None,
+        do_resize: Optional[bool] = None,
+        size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image to `crop_size`.
+            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Desired output size after applying the center crop.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
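+            # If a batch mixes channels-first and channels-last arrays, this inference only sees the
+            # first image; pass `input_data_format` explicitly in that case.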
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_center_crop: + images = [ + self.center_crop(image, crop_size, size=size, input_data_format=input_data_format) for image in images + ] + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["PerceiverImageProcessor"] diff --git a/docs/transformers/src/transformers/models/perceiver/image_processing_perceiver_fast.py b/docs/transformers/src/transformers/models/perceiver/image_processing_perceiver_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..87b24e768465b9fb7db0f06f0d8bb5c15cf038c4 --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/image_processing_perceiver_fast.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Perceiver.""" + +from typing import Optional, Union + +from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast, BatchFeature +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling, SizeDict +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +@add_start_docstrings( + "Constructs a fast Perceiver image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class PerceiverImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 224, "width": 224} + crop_size = {"height": 256, "width": 256} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + + def center_crop( + self, + image: "torch.Tensor", + crop_size: dict[str, int], + size: dict[str, int], + **kwargs, + ) -> "torch.Tensor": + """ + Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] * + min_dim)`. Where `min_dim = min(size["height"], size["width"])`. 
+ + If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then + center cropped. + + Args: + image (`"torch.Tensor"`): + Image to center crop. + crop_size (`Dict[str, int]`): + Desired output size after applying the center crop. + size (`Dict[str, int]`): + Size of the output image. + + Returns: + `torch.Tensor`: The center cropped image. + """ + if size.height is None or size.width is None: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}") + height, width = image.shape[-2:] + min_dim = min(height, width) + cropped_height = int((size.height / crop_size.height) * min_dim) + cropped_width = int((size.width / crop_size.width) * min_dim) + return F.center_crop(image, (cropped_height, cropped_width)) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, size=size, crop_size=crop_size) + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["PerceiverImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/perceiver/modeling_perceiver.py b/docs/transformers/src/transformers/models/perceiver/modeling_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..047fd5b6bafecb2f1b8908b597042aaa682fc756 --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/modeling_perceiver.py @@ -0,0 +1,3515 @@ +# coding=utf-8 +# Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
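Stepping back to `_preprocess` above: grouping same-shaped images lets every torchvision op run once on a stacked batch instead of image-by-image. A self-contained sketch of that idea, assuming plain PyTorch (the library's `group_images_by_shape`/`reorder_images` helpers are more general and their exact signatures are not reproduced here):

```python
import torch

def group_by_shape_sketch(images):
    """Group same-shaped tensors into stacked batches so ops run batched."""
    grouped, index = {}, []
    for img in images:
        grouped.setdefault(img.shape, []).append(img)
        index.append((img.shape, len(grouped[img.shape]) - 1))
    stacked = {shape: torch.stack(imgs) for shape, imgs in grouped.items()}
    return stacked, index  # `index` records where each image went, for reordering

imgs = [torch.rand(3, 224, 224), torch.rand(3, 300, 400), torch.rand(3, 224, 224)]
stacked, index = group_by_shape_sketch(imgs)
print({tuple(k): tuple(v.shape) for k, v in stacked.items()})
# {(3, 224, 224): (2, 3, 224, 224), (3, 300, 400): (1, 3, 300, 400)}
```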
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Perceiver model.""" + +import abc +import math +from dataclasses import dataclass +from functools import reduce +from operator import __add__ +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutputWithCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_perceiver import PerceiverConfig + + +ModalitySizeType = Mapping[str, int] +PreprocessorOutputType = Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor] +PreprocessorType = Callable[..., PreprocessorOutputType] +PostprocessorType = Callable[..., Any] + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "deepmind/language-perceiver" +_CONFIG_FOR_DOC = "PerceiverConfig" + + +@dataclass +class PerceiverModelOutput(ModelOutput): + """ + Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. 
+ """ + + logits: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PerceiverDecoderOutput(ModelOutput): + """ + Base class for Perceiver decoder outputs, with potential cross-attentions. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`): + Output of the basic decoder. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + logits: Optional[torch.FloatTensor] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PerceiverMaskedLMOutput(ModelOutput): + """ + Base class for Perceiver's masked language model outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_latents, + num_latents)`. Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PerceiverClassifierOutput(ModelOutput): + """ + Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal + autoencoding. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class PerceiverEmbeddings(nn.Module): + """Construct the latent embeddings.""" + + def __init__(self, config): + super().__init__() + self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents)) + + def forward(self, batch_size: int): + return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang + + +class PerceiverSelfAttention(nn.Module): + """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder.""" + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + ): + super().__init__() + self.num_heads = num_heads + # Q and K must have the same number of channels. + # Default to preserving Q's input's shape. + if qk_channels is None: + qk_channels = q_dim + # V's num_channels determines the shape of the output of QKV-attention. + # Default to the same number of channels used in the key-query operation. 
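To make the defaulting chain around this point concrete, a standalone sketch. The concrete numbers mirror the deepmind/language-perceiver cross-attention and are stated here as an assumption: `q_dim = d_latents = 1280`, `kv_dim = d_model = 768`, 8 heads.

```python
q_dim, kv_dim, num_heads = 1280, 768, 8  # assumed illustrative sizes
qk_channels, v_channels = None, None     # nothing specified by the caller

if qk_channels is None:       # Q and K default to the query input width
    qk_channels = q_dim       # 1280
if v_channels is None:        # V defaults to the Q/K width
    v_channels = qk_channels  # 1280

assert qk_channels % num_heads == 0 and v_channels % num_heads == 0
print(qk_channels // num_heads, v_channels // num_heads)  # 160 160 channels per head
```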
+ if v_channels is None: + v_channels = qk_channels + if qk_channels % num_heads != 0: + raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).") + if v_channels % num_heads != 0: + raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).") + + self.qk_channels = qk_channels + self.v_channels = v_channels + self.qk_channels_per_head = self.qk_channels // num_heads + self.v_channels_per_head = self.v_channels // num_heads + + # Layer normalization + self.layernorm1 = nn.LayerNorm(q_dim) + self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity() + + # Projection matrices + self.query = nn.Linear(q_dim, qk_channels) + self.key = nn.Linear(kv_dim, qk_channels) + self.value = nn.Linear(kv_dim, v_channels) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, channels_per_head): + new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + hidden_states = self.layernorm1(hidden_states) + inputs = self.layernorm2(inputs) + + # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module, + # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to. + is_cross_attention = inputs is not None + queries = self.query(hidden_states) + + if is_cross_attention: + keys = self.key(inputs) + values = self.value(inputs) + attention_mask = inputs_mask + else: + keys = self.key(hidden_states) + values = self.value(hidden_states) + + # Reshape channels for multi-head attention. + # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head) + queries = self.transpose_for_scores(queries, self.qk_channels_per_head) + keys = self.transpose_for_scores(keys, self.qk_channels_per_head) + values = self.transpose_for_scores(values, self.v_channels_per_head) + + # Take the dot product between the queries and keys to get the raw attention scores. + attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) + + batch_size, num_heads, seq_len, q_head_dim = queries.shape + _, _, _, v_head_dim = values.shape + hiddens = self.num_heads * v_head_dim + + attention_scores = attention_scores / math.sqrt(q_head_dim) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
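A standalone shape walk-through of the attention math above (all sizes are illustrative, not tied to any checkpoint):

```python
import math
import torch

# Illustrative sizes: 8 heads, 256 latent queries, 1024 input tokens.
batch, heads, q_len, kv_len, head_dim = 2, 8, 256, 1024, 32
queries = torch.rand(batch, heads, q_len, head_dim)
keys = torch.rand(batch, heads, kv_len, head_dim)
values = torch.rand(batch, heads, kv_len, head_dim)

scores = queries @ keys.transpose(-1, -2) / math.sqrt(head_dim)  # (2, 8, 256, 1024)
probs = scores.softmax(dim=-1)          # each query's weights over kv_len sum to 1
context = probs @ values                # (2, 8, 256, 32)
# Merge heads back to (batch, q_len, heads * head_dim), as the module does below.
merged = context.permute(0, 2, 1, 3).reshape(batch, q_len, heads * head_dim)
print(merged.shape)  # torch.Size([2, 256, 256])
```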
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, values) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (hiddens,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class PerceiverSelfOutput(nn.Module): + def __init__(self, config, input_channels, output_channels): + super().__init__() + self.dense = nn.Linear(input_channels, output_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + return hidden_states + + +class PerceiverAttention(nn.Module): + """Attention module, including a dense block.""" + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + use_query_residual=True, + ): + super().__init__() + # MultiHead attention + if is_cross_attention and qk_channels is None: + if config.cross_attention_shape_for_attention == "q": + qk_channels = q_dim + elif config.cross_attention_shape_for_attention == "kv": + qk_channels = kv_dim + else: + raise ValueError( + f"Unknown value {config.cross_attention_shape_for_attention} for " + "cross_attention_shape_for_attention." + ) + else: + if qk_channels is None: + qk_channels = q_dim + if v_channels is None: + v_channels = qk_channels + self.self = PerceiverSelfAttention( + config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + ) + # dense block + output_channels = None + if is_cross_attention: + output_channels = q_dim + else: + if output_channels is None: + output_channels = v_channels + self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels) + self.use_query_residual = use_query_residual + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + + # Output projection + attention_output = self.output(self_outputs[0]) + + # Optionally include a residual to the original queries. 
+ # Consider omitting the residual if the semantics of query and output + # are different, e.g. if queries are positions and outputs are pixels. + if self.use_query_residual: + attention_output = attention_output + hidden_states + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class PerceiverMLP(nn.Module): + """A Transformer-style dense module to follow attention.""" + + def __init__(self, config, input_size, widening_factor): + super().__init__() + self.dense1 = nn.Linear(input_size, widening_factor * input_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(widening_factor * input_size, input_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + return hidden_states + + +class PerceiverLayer(nn.Module): + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + widening_factor=4, + use_query_residual=True, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = PerceiverAttention( + config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + use_query_residual=use_query_residual, + ) + self.layernorm = nn.LayerNorm(q_dim) + self.mlp = PerceiverMLP(config, input_size=q_dim, widening_factor=widening_factor) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] # add attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + + layer_output = layer_output + attention_output # residual connection + + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + layer_output = self.layernorm(attention_output) + layer_output = self.mlp(layer_output) + return layer_output + + +class PerceiverEncoder(nn.Module): + """The Perceiver Encoder: a scalable, fully attentional encoder.""" + + def __init__(self, config, kv_dim=None): + super().__init__() + self.config = config + + # Check that we can use multihead-attention with these shapes. + if config.d_latents % config.num_self_attention_heads != 0: + raise ValueError( + f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_self_attend_heads ({config.num_self_attention_heads})." + ) + if config.d_latents % config.num_cross_attention_heads != 0: + raise ValueError( + f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_cross_attend_heads ({config.num_cross_attention_heads})." + ) + + # Construct the cross attention layer. 
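Before the layer construction below, it may help to see the encoder's overall dataflow as a sketch (stand-in callables in place of real layers; the config field names match those used in this file):

```python
import torch

def encode_sketch(latents, inputs, cross_attend, self_attends, num_blocks):
    # One cross-attention read of the (possibly very long) inputs into the latents...
    hidden = cross_attend(latents, inputs)
    # ...then the same block of self-attention layers applied num_blocks times.
    # Reusing the modules across repeats shares weights: depth without extra parameters.
    for _ in range(num_blocks):
        for layer in self_attends:
            hidden = layer(hidden)
    return hidden

latents = torch.zeros(2, 256, 1280)   # (batch, num_latents, d_latents)
inputs = torch.zeros(2, 2048, 768)    # (batch, seq_len, d_model), e.g. byte embeddings
out = encode_sketch(
    latents, inputs,
    cross_attend=lambda h, x: h,      # stand-in for the real cross-attention layer
    self_attends=[lambda h: h] * 26,  # num_self_attends_per_block stand-ins
    num_blocks=1,
)
print(out.shape)  # torch.Size([2, 256, 1280]); cost scales with num_latents, not seq_len
```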
+ self.cross_attention = PerceiverLayer( + config, + is_cross_attention=True, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_cross_attention_heads, + q_dim=config.d_latents, + kv_dim=kv_dim, + widening_factor=config.cross_attention_widening_factor, + use_query_residual=config.use_query_residual, + ) + + # Construct a single block of self-attention layers. + # We get deeper architectures by applying this block more than once. + self_attention_layers = [] + for _ in range(config.num_self_attends_per_block): + layer = PerceiverLayer( + config, + is_cross_attention=False, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_self_attention_heads, + q_dim=config.d_latents, + kv_dim=config.d_latents, + widening_factor=config.self_attention_widening_factor, + ) + self_attention_layers.append(layer) + + self.self_attends = nn.ModuleList(self_attention_layers) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + # Apply the cross-attention between the latents (hidden_states) and inputs: + layer_outputs = self.cross_attention( + hidden_states, + attention_mask=attention_mask, + head_mask=None, + inputs=inputs, + inputs_mask=inputs_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_cross_attentions = all_cross_attentions + (layer_outputs[1],) + + # Apply the block of self-attention layers more than once: + for _ in range(self.config.num_blocks): + for i, layer_module in enumerate(self.self_attends): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class PerceiverPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = PerceiverConfig + base_model_prefix = "perceiver" + main_input_name = "inputs" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif hasattr(module, "latents"): + module.latents.data.normal_(mean=0.0, std=self.config.initializer_range) + elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding): + module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.ParameterDict): + for modality in module.keys(): + module[modality].data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +PERCEIVER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PERCEIVER_MODEL_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + decoder (*DecoderType*, *optional*): + Optional decoder to use to decode the latent representation of the encoder. Examples include + *transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder*, + *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationDecoder*, + *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder*. + input_preprocessor (*PreprocessorType*, *optional*): + Optional input preprocessor to use. Examples include + *transformers.models.perceiver.modeling_perceiver.PerceiverImagePreprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPreprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor*. + output_postprocessor (*PostprocessorType*, *optional*): + Optional output postprocessor to use. 
Examples include + *transformers.models.perceiver.modeling_perceiver.PerceiverImagePostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationPostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverProjectionPostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPostprocessor*. + + Note that you can define your own decoders, preprocessors and/or postprocessors to fit your use-case. +""" + +PERCEIVER_INPUTS_DOCSTRING = r""" + Args: + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """The Perceiver: a scalable, fully attentional architecture. + + + + Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, + PERCEIVER_MODEL_START_DOCSTRING, +) +class PerceiverModel(PerceiverPreTrainedModel): + def __init__( + self, + config, + decoder=None, + input_preprocessor: PreprocessorType = None, + output_postprocessor: PostprocessorType = None, + ): + super().__init__(config) + self.config = config + + self.input_preprocessor = input_preprocessor + self.output_postprocessor = output_postprocessor + self.embeddings = PerceiverEmbeddings(config) + self.encoder = PerceiverEncoder( + config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model + ) + self.decoder = decoder + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.latents + + def set_input_embeddings(self, value): + self.embeddings.latents = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PerceiverModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverImageProcessor, PerceiverModel + >>> from transformers.models.perceiver.modeling_perceiver import ( + ... PerceiverTextPreprocessor, + ... PerceiverImagePreprocessor, + ... PerceiverClassificationDecoder, + ... ) + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> # EXAMPLE 1: using the Perceiver to classify texts + >>> # - we define a TextPreprocessor, which can be used to embed tokens + >>> # - we define a ClassificationDecoder, which can be used to decode the + >>> # final hidden states of the latents to classification logits + >>> # using trainable position embeddings + >>> config = PerceiverConfig() + >>> preprocessor = PerceiverTextPreprocessor(config) + >>> decoder = PerceiverClassificationDecoder( + ... config, + ... num_channels=config.d_latents, + ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1), + ... use_query_residual=True, + ... ) + >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder) + + >>> # you can then do a forward pass as follows: + >>> tokenizer = PerceiverTokenizer() + >>> text = "hello world" + >>> inputs = tokenizer(text, return_tensors="pt").input_ids + + >>> with torch.no_grad(): + ... outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2] + + >>> # to train, one can train the model using standard cross-entropy: + >>> criterion = torch.nn.CrossEntropyLoss() + + >>> labels = torch.tensor([1]) + >>> loss = criterion(logits, labels) + + >>> # EXAMPLE 2: using the Perceiver to classify images + >>> # - we define an ImagePreprocessor, which can be used to embed images + >>> config = PerceiverConfig(image_size=224) + >>> preprocessor = PerceiverImagePreprocessor( + ... config, + ... prep_type="conv1x1", + ... spatial_downsample=1, + ... out_channels=256, + ... position_encoding_type="trainable", + ... concat_or_add_pos="concat", + ... project_pos_dim=256, + ... trainable_position_encoding_kwargs=dict( + ... num_channels=256, + ... index_dims=config.image_size**2, + ... ), + ... ) + + >>> model = PerceiverModel( + ... config, + ... input_preprocessor=preprocessor, + ... decoder=PerceiverClassificationDecoder( + ... config, + ... num_channels=config.d_latents, + ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1), + ... use_query_residual=True, + ... ), + ... 
) + + >>> # you can then do a forward pass as follows: + >>> image_processor = PerceiverImageProcessor() + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = image_processor(image, return_tensors="pt").pixel_values + + >>> with torch.no_grad(): + ... outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2] + + >>> # to train, one can train the model using standard cross-entropy: + >>> criterion = torch.nn.CrossEntropyLoss() + + >>> labels = torch.tensor([1]) + >>> loss = criterion(logits, labels) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.input_preprocessor is not None: + inputs, modality_sizes, inputs_without_pos = self.input_preprocessor( + inputs, interpolate_pos_encoding=interpolate_pos_encoding + ) + else: + modality_sizes = None + inputs_without_pos = None + if inputs.size()[-1] != self.config.d_model: + raise ValueError( + f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:" + f" {self.config.d_model}. Make sure to set config.d_model appropriately." + ) + + batch_size, seq_length, _ = inputs.size() + device = inputs.device + + # If no attention mask is provided, make them all ones + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length), device=device) + # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = self.invert_attention_mask(attention_mask) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_blocks x num_heads] + # and head_mask is converted to shape [num_blocks x batch x num_heads x N x N] + head_mask = self.get_head_mask(head_mask, self.config.num_blocks * self.config.num_self_attends_per_block) + + embedding_output = self.embeddings(batch_size=batch_size) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=None, + head_mask=head_mask, + inputs=inputs, + inputs_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + logits = None + if self.decoder: + if subsampled_output_points is not None: + output_modality_sizes = { + "audio": subsampled_output_points["audio"].shape[0], + "image": subsampled_output_points["image"].shape[0], + "label": 1, + } + else: + output_modality_sizes = modality_sizes + decoder_query = self.decoder.decoder_query( + inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points + ) + decoder_outputs = self.decoder( + decoder_query, + z=sequence_output, + query_mask=extended_attention_mask, + output_attentions=output_attentions, + ) + logits = decoder_outputs.logits + + # add cross-attentions of decoder + if output_attentions and decoder_outputs.cross_attentions is not None: + if return_dict: + encoder_outputs.cross_attentions = ( + encoder_outputs.cross_attentions + decoder_outputs.cross_attentions + ) + else: + encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions + + if 
self.output_postprocessor: + logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes) + + if not return_dict: + if logits is not None: + return (logits, sequence_output) + encoder_outputs[1:] + else: + return (sequence_output,) + encoder_outputs[1:] + + return PerceiverModelOutput( + logits=logits, + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings("""Example use of Perceiver for masked language modeling.""", PERCEIVER_START_DOCSTRING) +class PerceiverForMaskedLM(PerceiverPreTrainedModel): + def __init__(self, config: PerceiverConfig): + super().__init__(config) + + text_preprocessor = PerceiverTextPreprocessor(config) + + trainable_position_encoding_kwargs_decoder = { + "num_channels": text_preprocessor.num_channels, + "index_dims": config.max_position_embeddings, + } + + self.perceiver = PerceiverModel( + config, + input_preprocessor=text_preprocessor, + decoder=PerceiverBasicDecoder( + config, + output_num_channels=config.d_latents, + output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand + num_channels=text_preprocessor.num_channels, + qk_channels=8 * 32, + v_channels=text_preprocessor.num_channels, + num_heads=8, + use_query_residual=False, + final_project=False, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + ), + ) + self.embedding_decoder = PerceiverEmbeddingDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + input_ids: Optional[torch.Tensor] = None, + ) -> Union[Tuple, PerceiverMaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, PerceiverForMaskedLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver") + >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver") + + >>> # training + >>> text = "This is an incomplete sentence where some words are missing." + >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt") + >>> # mask " missing." 
+ >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id + >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids + + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> round(loss.item(), 2) + 19.87 + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2048, 262] + + >>> # inference + >>> text = "This is an incomplete sentence where some words are missing." + >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt") + + >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space. + >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id + + >>> # forward pass + >>> with torch.no_grad(): + ... outputs = model(**encoding) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2048, 262] + + >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist() + >>> tokenizer.decode(masked_tokens_predictions) + ' missing.' + ```""" + if inputs is not None and input_ids is not None: + raise ValueError("You cannot use both `inputs` and `input_ids`") + elif inputs is None and input_ids is not None: + inputs = input_ids + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.embedding_decoder( + outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings + ) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return PerceiverMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings("""Example use of Perceiver for text classification.""", PERCEIVER_START_DOCSTRING) +class PerceiverForSequenceClassification(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverTextPreprocessor(config), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + input_ids: Optional[torch.Tensor] = None, + ) -> Union[Tuple, 
PerceiverClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels - + 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > + 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, PerceiverForSequenceClassification + + >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver") + >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver") + + >>> text = "hello world" + >>> inputs = tokenizer(text, return_tensors="pt").input_ids + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 2] + ```""" + if inputs is not None and input_ids is not None: + raise ValueError("You cannot use both `inputs` and `input_ids`") + elif inputs is None and input_ids is not None: + inputs = input_ids + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ +Example use of Perceiver for image classification, for tasks such as ImageNet. + +This model uses learned position embeddings. In other words, this model is not given any privileged information about +the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet. + +[`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] +(with `prep_type="conv1x1"`) to preprocess the input images, and +[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of +[`PerceiverModel`] into classification logits. 
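All of the classification heads in this file share the `problem_type` loss selection seen in the forward above; a condensed standalone sketch of just that dispatch (shapes are illustrative):

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def pick_loss(logits, labels, num_labels, problem_type=None):
    # Infer the problem type once from num_labels and label dtype, as the forward does.
    if problem_type is None:
        if num_labels == 1:
            problem_type = "regression"
        elif labels.dtype in (torch.long, torch.int):
            problem_type = "single_label_classification"
        else:
            problem_type = "multi_label_classification"
    if problem_type == "regression":
        return MSELoss()(logits.squeeze(), labels.squeeze())
    if problem_type == "single_label_classification":
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(logits, labels)

print(pick_loss(torch.randn(4, 2), torch.tensor([0, 1, 1, 0]), num_labels=2))
```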
+""", + PERCEIVER_START_DOCSTRING, +) +class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + trainable_position_encoding_kwargs_preprocessor = {"num_channels": 256, "index_dims": config.image_size**2} + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="conv1x1", + spatial_downsample=1, + out_channels=256, + position_encoding_type="trainable", + concat_or_add_pos="concat", + project_pos_dim=256, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[Tuple, PerceiverClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationLearned + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-learned") + >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned") + + >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 1000] + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: tabby, tabby cat + ```""" + if inputs is not None and pixel_values is not None: + raise ValueError("You cannot use both `inputs` and `pixel_values`") + elif inputs is None and pixel_values is not None: + inputs = pixel_values + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ +Example use of Perceiver for image classification, for tasks such as ImageNet. + +This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of +79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT). + +[`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] +(with `prep_type="pixels"`) to preprocess the input images, and +[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of +[`PerceiverModel`] into classification logits. 
+""", + PERCEIVER_START_DOCSTRING, +) +class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = { + "concat_pos": True, + "max_resolution": (224, 224), + "num_bands": 64, + "sine_only": False, + } + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="pixels", + spatial_downsample=1, + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[Tuple, PerceiverClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationFourier
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-fourier")
+        >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
+        >>> outputs = model(inputs=inputs)
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 1000]
+
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        Predicted class: tabby, tabby cat
+        ```"""
+        if inputs is not None and pixel_values is not None:
+            raise ValueError("You cannot use both `inputs` and `pixel_values`")
+        elif inputs is None and pixel_values is not None:
+            inputs = pixel_values
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.perceiver(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits if return_dict else outputs[0]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return PerceiverClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+Example use of Perceiver for image classification, for tasks such as ImageNet.
+
+This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
+of 82.1 on ImageNet.
+
+[`PerceiverForImageClassificationConvProcessing`] uses
+[`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
+(with `prep_type="conv"`) to preprocess the input images, and
+[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
+[`PerceiverModel`] into classification logits.
+""", + PERCEIVER_START_DOCSTRING, +) +class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = { + "concat_pos": True, + "max_resolution": (56, 56), + "num_bands": 64, + "sine_only": False, + } + trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1} + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="conv", + spatial_downsample=1, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[Tuple, PerceiverClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationConvProcessing
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-conv")
+        >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
+        >>> outputs = model(inputs=inputs)
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 1000]
+
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        Predicted class: tabby, tabby cat
+        ```"""
+        if inputs is not None and pixel_values is not None:
+            raise ValueError("You cannot use both `inputs` and `pixel_values`")
+        elif inputs is None and pixel_values is not None:
+            inputs = pixel_values
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.perceiver(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits if return_dict else outputs[0]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return PerceiverClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses
+[`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the
+input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent
+representation of [`PerceiverModel`].
+
+As input, one concatenates 2 subsequent frames along the channel dimension and extracts a 3 x 3 patch around each
+pixel (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the
+position of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent
+representation using the same encoding used for the input.
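+
+A minimal sketch of how such patch inputs can be constructed with `torch.nn.functional.unfold` (an illustration, not
+necessarily the authors' exact preprocessing pipeline):
+
+```python
+import torch
+
+frames = torch.randn(1, 2, 3, 368, 496)  # (batch_size, num_frames, channels, height, width)
+batch_size, num_frames, channels, height, width = frames.shape
+# gather the 3 x 3 neighborhood around every pixel: 3 x 3 x 3 = 27 values per pixel and frame
+patches = torch.nn.functional.unfold(frames.reshape(-1, channels, height, width), kernel_size=3, padding=1)
+patches = patches.reshape(batch_size, num_frames, channels * 9, height, width)  # (1, 2, 27, 368, 496)
+```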
+""", + PERCEIVER_START_DOCSTRING, +) +class PerceiverForOpticalFlow(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = { + "num_bands": 64, + "max_resolution": config.train_size, + "sine_only": False, + "concat_pos": True, + } + fourier_position_encoding_kwargs_decoder = { + "concat_pos": True, + "max_resolution": config.train_size, + "num_bands": 64, + "sine_only": False, + } + + image_preprocessor = PerceiverImagePreprocessor( + config, + prep_type="patches", + spatial_downsample=1, + conv_after_patching=True, + conv_after_patching_in_channels=54, + temporal_downsample=2, + position_encoding_type="fourier", + # position_encoding_kwargs + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=image_preprocessor, + decoder=PerceiverOpticalFlowDecoder( + config, + num_channels=image_preprocessor.num_channels, + output_image_shape=config.train_size, + rescale_factor=100.0, + # decoder kwargs + use_query_residual=False, + output_num_channels=2, + # We query the decoder using the first frame features + # rather than a standard decoder position encoding. + position_encoding_type="fourier", + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PerceiverClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import PerceiverForOpticalFlow
+        >>> import torch
+
+        >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
+
+        >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel,
+        >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels)
+        >>> # patches have shape (batch_size, num_frames, num_channels, height, width)
+        >>> # the authors train on resolutions of 368 x 496
+        >>> patches = torch.randn(1, 2, 27, 368, 496)
+        >>> outputs = model(inputs=patches)
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 368, 496, 2]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        loss = None
+        if labels is not None:
+            raise NotImplementedError("Optical flow training is not yet supported")
+
+        outputs = self.perceiver(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits if return_dict else outputs[0]
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return PerceiverClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700.
+
+[`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to
+preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to
+preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad
+each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies
+the Perceiver encoder.
+
+[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of
+[`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are
+created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is
+computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent
+representation. This is determined by the subsampled indices for each modality, which can be provided as additional
+input to the forward pass of [`PerceiverForMultimodalAutoencoding`].
+
+[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different
+modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention
+is performed with the latent representation of [`PerceiverModel`].
+
+Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPostprocessor`] is used to turn this tensor into an
+actual video. It first splits up the output into the different modalities, and then applies the respective
+postprocessor for each modality.
+
+Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the
+"label" modality), this auto-encoding model becomes a Kinetics-700 video classifier.
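+
+A minimal sketch of this classifier mode, reusing the input construction from the example in
+[`PerceiverForMultimodalAutoencoding.forward`] below (`model`, `images`, `audio` and `subsampling` as defined there):
+
+```python
+inputs = {"image": images, "audio": audio, "label": torch.zeros((images.shape[0], 700))}  # masked label
+outputs = model(inputs=inputs, subsampled_output_points=subsampling)
+predicted_class_idx = outputs.logits["label"].argmax(-1).item()  # Kinetics-700 class prediction
+```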
+""",
+    PERCEIVER_START_DOCSTRING,
+)
+class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
+    def __init__(self, config: PerceiverConfig):
+        super().__init__(config)
+
+        n_audio_samples = config.num_frames * config.audio_samples_per_frame
+
+        input_preprocessor = PerceiverMultimodalPreprocessor(
+            min_padding_size=4,
+            modalities={
+                "audio": PerceiverAudioPreprocessor(
+                    config,
+                    position_encoding_type="fourier",
+                    fourier_position_encoding_kwargs={
+                        "num_bands": 192,
+                        "max_resolution": (n_audio_samples,),
+                        "sine_only": False,
+                        "concat_pos": True,
+                    },
+                    prep_type="patches",
+                    samples_per_patch=config.samples_per_patch,
+                ),
+                "image": PerceiverImagePreprocessor(
+                    config,
+                    position_encoding_type="fourier",
+                    fourier_position_encoding_kwargs={
+                        "num_bands": 32,
+                        "max_resolution": (config.num_frames, config.image_size, config.image_size),
+                        "sine_only": False,
+                        "concat_pos": True,
+                    },
+                    prep_type="patches",
+                    spatial_downsample=4,
+                    temporal_downsample=1,
+                ),
+                "label": PerceiverOneHotPreprocessor(config),
+            },
+            mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0},
+        )
+
+        image_decoder = PerceiverBasicVideoAutoencodingDecoder(
+            config,
+            # Autoencoding, don't pass inputs to the queries.
+            concat_preprocessed_input=False,
+            output_shape=config.output_shape,
+            output_num_channels=config.output_num_channels,
+            use_query_residual=False,
+            position_encoding_only=True,
+            position_encoding_type="fourier",
+            fourier_position_encoding_kwargs={
+                "num_bands": 32,
+                "max_resolution": (config.num_frames, config.image_size, config.image_size),
+                "sine_only": False,
+                "concat_pos": True,
+            },
+        )
+
+        decoder = PerceiverMultimodalDecoder(
+            config,
+            # Autoencoding, don't pass inputs to the queries.
+            concat_preprocessed_input=False,
+            # Modality-specific decoders are used ONLY to generate queries.
+            # All modalities are decoded together using a unified decoder.
+            modalities={
+                "audio": PerceiverBasicDecoder(
+                    config,
+                    # Autoencoding, don't pass inputs to the queries.
+                    concat_preprocessed_input=False,
+                    output_index_dims=(n_audio_samples // config.samples_per_patch,),
+                    output_num_channels=config.output_num_channels,
+                    use_query_residual=False,
+                    position_encoding_only=True,
+                    position_encoding_type="fourier",
+                    fourier_position_encoding_kwargs={
+                        "num_bands": 192,
+                        "max_resolution": (n_audio_samples,),
+                        "sine_only": False,
+                        "concat_pos": True,
+                    },
+                ),
+                "image": image_decoder,
+                "label": PerceiverClassificationDecoder(
+                    config,
+                    # Autoencoding, don't pass inputs to the queries.
+ concat_preprocessed_input=False, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="trainable", + trainable_position_encoding_kwargs={ + "num_channels": config._label_trainable_num_channels, + "index_dims": 1, + }, + ), + }, + num_outputs=None, + output_num_channels=config.output_num_channels, + use_query_residual=False, + ) + + output_postprocessor = PerceiverMultimodalPostprocessor( + modalities={ + "audio": PerceiverAudioPostprocessor(config, in_channels=config.output_num_channels), + "image": PerceiverProjectionPostprocessor(in_channels=config.output_num_channels, out_channels=3), + "label": PerceiverClassificationPostprocessor(config, in_channels=config.output_num_channels), + } + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=input_preprocessor, + decoder=decoder, + output_postprocessor=output_postprocessor, + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PerceiverClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverForMultimodalAutoencoding + >>> import torch + >>> import numpy as np + + >>> # create multimodal inputs + >>> images = torch.randn((1, 16, 3, 224, 224)) + >>> audio = torch.randn((1, 30720, 1)) + >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700))) + + >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver") + + >>> # in the Perceiver IO paper, videos are auto-encoded in chunks + >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries + >>> nchunks = 128 + >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks + >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks + >>> # process the first chunk + >>> chunk_idx = 0 + >>> subsampling = { + ... "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)), + ... "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)), + ... "label": None, + ... 
} + + >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling) + >>> logits = outputs.logits + >>> list(logits["audio"].shape) + [1, 240] + + >>> list(logits["image"].shape) + [1, 6272, 3] + + >>> list(logits["label"].shape) + [1, 700] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + if labels is not None: + raise NotImplementedError("Multimodal autoencoding training is not yet supported") + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + subsampled_output_points=subsampled_output_points, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Below: position encodings + + +def build_position_encoding( + position_encoding_type, + out_channels=None, + project_pos_dim=-1, + trainable_position_encoding_kwargs=None, + fourier_position_encoding_kwargs=None, +): + """ + Builds the position encoding. + + Args: + - out_channels: refers to the number of channels of the position encodings. + - project_pos_dim: if specified, will project the position encodings to this dimension. + + """ + + if position_encoding_type == "trainable": + if not trainable_position_encoding_kwargs: + raise ValueError("Make sure to pass trainable_position_encoding_kwargs") + output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs) + elif position_encoding_type == "fourier": + # We don't use the index_dims argument, as this is only known during the forward pass + if not fourier_position_encoding_kwargs: + raise ValueError("Make sure to pass fourier_position_encoding_kwargs") + output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs) + else: + raise ValueError(f"Unknown position encoding type: {position_encoding_type}.") + + # Optionally, project the position encoding to a target dimension: + positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity() + + return output_pos_enc, positions_projection + + +# Below: Perceiver decoders + + +class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta): + """Perceiver abstract decoder.""" + + @abc.abstractmethod + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + raise NotImplementedError + + @property + @abc.abstractmethod + def num_query_channels(self): + raise NotImplementedError + + @abc.abstractmethod + def forward(self, query, z, query_mask=None): + raise NotImplementedError + + +class PerceiverProjectionDecoder(PerceiverAbstractDecoder): + """ + Baseline projection decoder (no cross-attention). + + Args: + config ([`PerceiverConfig`]): + Model configuration. 
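+
+    Example (a minimal sketch of the decoding step, assuming `config` is a [`PerceiverConfig`] in scope):
+
+    ```python
+    import torch
+
+    decoder = PerceiverProjectionDecoder(config)
+    z = torch.randn(2, config.num_latents, config.d_latents)  # final latent array of the encoder
+    logits = decoder(query=None, z=z)  # latents are mean-pooled over dim 1, then projected
+    # logits has shape (2, config.num_labels)
+    ```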
+ """ + + def __init__(self, config): + super().__init__() + self.classifier = nn.Linear(config.d_latents, config.num_labels) + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return None + + def forward( + self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + # (batch_size, num_latents, d_latents) -> (batch_size, d_latents) + z = torch.mean(z, dim=1) + # (batch_size, d_latents) -> (batch_size, config.num_labels) + logits = self.classifier(z) + return logits + + +class PerceiverBasicDecoder(PerceiverAbstractDecoder): + """ + Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a + cross-attention operation, in which the latents produce keys and values. + + The shape of the output of this class depends on how one defines the output queries (also called decoder queries). + + Args: + config ([*PerceiverConfig*]): + Model configuration. + output_num_channels (`int`, *optional*): + The number of channels in the output. Will only be used in case *final_project* is set to `True`. + position_encoding_type (`str`, *optional*, defaults to "trainable"): + The type of position encoding to use. Can be either "trainable", "fourier", or "none". + output_index_dims (`int`, *optional*): + The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'. + num_channels (`int`, *optional*, defaults to 128): + The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'. + qk_channels (`int`, *optional*): + The number of channels of the queries and keys in the cross-attention layer. + v_channels (`int`, *optional*): + The number of channels of the values in the cross-attention layer. + num_heads (`int`, *optional*, defaults to 1): + The number of attention heads in the cross-attention layer. + widening_factor (`int`, *optional*, defaults to 1): + The widening factor of the cross-attention layer. + use_query_residual (`bool`, *optional*, defaults to `False`): + Whether to use a residual connection between the query and the output of the cross-attention layer. + concat_preprocessed_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the preprocessed input to the query. + final_project (`bool`, *optional*, defaults to `True`): + Whether to project the output of the cross-attention layer to a target dimension. + position_encoding_only (`bool`, *optional*, defaults to `False`): + Whether to only use this class to define output queries. + """ + + def __init__( + self, + config: PerceiverConfig, + output_num_channels: int, + position_encoding_type: Optional[str] = "trainable", + # The following 2 arguments are ignored if position_encoding_type == 'none': + output_index_dims: Optional[int] = None, + num_channels: Optional[int] = 128, + subsampled_index_dims: Optional[int] = None, + qk_channels: Optional[int] = None, + v_channels: Optional[int] = None, + num_heads: Optional[int] = 1, + widening_factor: Optional[int] = 1, + use_query_residual: Optional[bool] = False, + concat_preprocessed_input: Optional[bool] = False, + final_project: Optional[bool] = True, + position_encoding_only: Optional[bool] = False, + **position_encoding_kwargs, + ) -> None: + super().__init__() + + self.output_num_channels = output_num_channels + # If `none`, the decoder will not construct any position encodings. + # You should construct your own when querying the decoder. 
+ self.output_position_encodings = None + self.position_encoding_type = position_encoding_type + self.position_encoding_kwargs = position_encoding_kwargs + if position_encoding_type != "none": + self.output_position_encodings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, **position_encoding_kwargs + ) + + self.output_index_dims = output_index_dims + self.num_channels = num_channels + if subsampled_index_dims is None: + subsampled_index_dims = output_index_dims + self.subsampled_index_dims = subsampled_index_dims + self.concat_preprocessed_input = concat_preprocessed_input + self.final_project = final_project + self.position_encoding_only = position_encoding_only + + # for multimodal autoencoding, we don't need the decoder cross-attention and final layer + # so then we will set position_encoding_only to True + if not self.position_encoding_only: + self.decoding_cross_attention = PerceiverLayer( + config, + is_cross_attention=True, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=num_channels, + kv_dim=config.d_latents, + widening_factor=widening_factor, + use_query_residual=use_query_residual, + ) + self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity() + + @property + def num_query_channels(self) -> int: + if self.position_encoding_type == "none": # Queries come from elsewhere + raise ValueError( + "You cannot calculate number of decoder query channels when position_encoding_type is set to none" + ) + if self.position_encoding_only: + if "project_pos_dim" in self.position_encoding_kwargs: + return self.position_encoding_kwargs["project_pos_dim"] + return self.output_position_encodings.output_size() + if self.final_project: + return self.output_num_channels + return self.num_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + if self.position_encoding_type == "none": # Queries come from elsewhere + raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none") + if subsampled_points is not None: + # subsampled_points are the indices if the inputs would be flattened + # however, the inputs aren't flattened, that's why we use unravel_index + # to get the indices for the unflattened array + # unravel_index returns a tuple (x_idx, y_idx, ...) + # stack to get the [n, d] tensor of coordinates + indices = [torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)] + pos = torch.stack(indices, dim=1) + batch_size = inputs.shape[0] + # Map these coordinates to [-1, 1] + pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :] + pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]]) + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_emb = self.output_position_encodings(batch_size) + elif self.position_encoding_type == "fourier": + pos_emb = self.output_position_encodings( + self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos + ) + + # Optionally project them to a target dimension. + pos_emb = self.positions_projection(pos_emb) + pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]]) + else: + batch_size = inputs.shape[0] + index_dims = inputs.shape[2:] + + # Construct the position encoding. 
+ if self.position_encoding_type == "trainable": + pos_emb = self.output_position_encodings(batch_size) + elif self.position_encoding_type == "fourier": + pos_emb = self.output_position_encodings( + index_dims, batch_size, device=inputs.device, dtype=inputs.dtype + ) + + # Optionally project them to a target dimension. + pos_emb = self.positions_projection(pos_emb) + + if self.concat_preprocessed_input: + if inputs_without_pos is None: + raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True") + pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1) + + return pos_emb + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: + # Cross-attention decoding. + # key, value: B x N x K; query: B x M x K + # Attention maps -> B x N x M + # Output -> B x M x K + cross_attentions = () if output_attentions else None + + layer_outputs = self.decoding_cross_attention( + query, + attention_mask=query_mask, + head_mask=None, + inputs=z, + inputs_mask=None, + output_attentions=output_attentions, + ) + output = layer_outputs[0] + + if output_attentions: + cross_attentions = cross_attentions + (layer_outputs[1],) + + logits = self.final_layer(output) + + return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions) + + +class PerceiverClassificationDecoder(PerceiverAbstractDecoder): + """ + Cross-attention based classification decoder. Light-weight wrapper of [`PerceiverBasicDecoder`] for logit output. + Will turn the output of the Perceiver encoder which is of shape (batch_size, num_latents, d_latents) to a tensor of + shape (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels). + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config, **decoder_kwargs): + super().__init__() + + self.num_labels = config.num_labels + self.decoder = PerceiverBasicDecoder( + config, + output_num_channels=self.num_labels, + output_index_dims=1, # Predict a single logit array. 
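+            # with a single index dim, the decoder query has shape (batch_size, 1, num_query_channels)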
+ **decoder_kwargs, + ) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return self.decoder.decoder_query( + inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points + ) + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: + decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) + + # B x 1 x num_classes -> B x num_classes + logits = decoder_outputs.logits[:, 0, :] + + return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions) + + +class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder): + """Cross-attention based optical flow decoder.""" + + def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs): + super().__init__() + + self.output_image_shape = output_image_shape + self.output_num_channels = output_num_channels + self.rescale_factor = rescale_factor + self.decoder = PerceiverBasicDecoder(config, output_num_channels=output_num_channels, **decoder_kwargs) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + if subsampled_points is not None: + raise ValueError("FlowDecoder doesn't support subsampling yet.") + return inputs + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> PerceiverDecoderOutput: + decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) + preds = decoder_outputs.logits + # Output flow and rescale. + preds /= self.rescale_factor + preds = preds.reshape([preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]]) + return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions) + + +class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder): + """ + Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video + reshaping logic. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + output_shape (`List[int]`): + Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension. + position_encoding_type (`str`): + The type of position encoding to use. Can be either "trainable", "fourier", or "none". 
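+
+    Example (a minimal sketch of the final reshaping step performed in `forward`, with hypothetical shapes):
+
+    ```python
+    import torch
+
+    output_shape = [1, 16, 224, 224]  # (batch_size, num_frames, height, width)
+    flat = torch.randn(1, 16 * 224 * 224, 3)  # flat decoder output with 3 channels
+    video = torch.reshape(flat, output_shape + [flat.shape[-1]])  # (1, 16, 224, 224, 3)
+    ```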
+ """ + + def __init__( + self, config: PerceiverConfig, output_shape: List[int], position_encoding_type: str, **decoder_kwargs + ) -> None: + super().__init__() + if len(output_shape) != 4: # B, T, H, W + raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.") + # Build the decoder components: + self.output_shape = output_shape + self.output_num_channels = decoder_kwargs["output_num_channels"] + + self.decoder = PerceiverBasicDecoder( + config, + output_index_dims=self.output_shape[1:4], # T*H*W + position_encoding_type=position_encoding_type, + **decoder_kwargs, + ) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return self.decoder.decoder_query( + inputs, + modality_sizes=modality_sizes, + inputs_without_pos=inputs_without_pos, + subsampled_points=subsampled_points, + ) + + def forward( + self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None + ) -> PerceiverDecoderOutput: + decoder_outputs = self.decoder(query, z) + logits = decoder_outputs.logits + + logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]]) + return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions) + + +def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]: + """ + Partitions a [B, N, C] tensor into tensors for each modality. + + Args: + modality_sizes + dict specifying the size of the modality + inputs: + input tensor + + Returns: + dict mapping name of modality to its associated tensor. + """ + outputs = {} + index = 0 + # Apply a predictable ordering to the modalities + for modality in sorted(modality_sizes.keys()): + size = modality_sizes[modality] + inp = inputs[:, index : index + size] + index += size + outputs[modality] = inp + return outputs + + +class PerceiverMultimodalDecoder(PerceiverAbstractDecoder): + """ + Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary + mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that + modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are + concatenated along the time dimension. + + Next, there is a shared cross attention operation across all modalities. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + modalities (`Dict[str, PerceiverAbstractDecoder]`): + Dictionary mapping modality name to the decoder of that modality. + num_outputs (`int`): + The number of outputs of the decoder. + output_num_channels (`int`): + The number of channels in the output. + min_padding_size (`int`, *optional*, defaults to 2): + The minimum padding size for all modalities. The final output will have num_channels equal to the maximum + channels across all modalities plus min_padding_size. + subsampled_index_dims (`Dict[str, PerceiverAbstractDecoder]`, *optional*): + Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that + modality. 
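+
+    The common query width follows from the rule in `num_query_channels`; a sketch with hypothetical per-modality
+    channel counts:
+
+    ```python
+    query_channels = {"audio": 385, "image": 195, "label": 1024}  # hypothetical decoder query widths
+    min_padding_size = 2
+    common_channel_size = max(query_channels.values()) + min_padding_size  # every query is padded to 1026
+    ```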
+ """ + + def __init__( + self, + config: PerceiverConfig, + modalities: Dict[str, PerceiverAbstractDecoder], + num_outputs: int, + output_num_channels: int, + min_padding_size: Optional[int] = 2, + subsampled_index_dims: Optional[Dict[str, PerceiverAbstractDecoder]] = None, + **decoder_kwargs, + ) -> None: + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.subsampled_index_dims = subsampled_index_dims + self.min_padding_size = min_padding_size + self.output_num_channels = output_num_channels + self.num_outputs = num_outputs + self.decoder = PerceiverBasicDecoder( + config, + output_index_dims=(num_outputs,), + output_num_channels=output_num_channels, + position_encoding_type="none", + num_channels=self.num_query_channels, + **decoder_kwargs, + ) + self.padding = nn.ParameterDict( + { + modality: nn.Parameter(torch.randn(1, self.num_query_channels - decoder.num_query_channels)) + for modality, decoder in modalities.items() + } + ) + + @property + def num_query_channels(self) -> int: + max_channel_size = max(decoder.num_query_channels for _, decoder in self.modalities.items()) + common_channel_size = max_channel_size + self.min_padding_size + return common_channel_size + + def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None): + # Partition the flat inputs among the different modalities + inputs = restructure(modality_sizes, inputs) + + # Obtain modality-specific decoders' queries + subsampled_points = subsampled_points or {} + + decoder_queries = {} + for modality, decoder in self.modalities.items(): + # Get input_without_pos for this modality if it exists. + input_without_pos = None + if inputs_without_pos is not None: + input_without_pos = inputs_without_pos.get(modality, None) + query = decoder.decoder_query( + inputs=inputs[modality], + modality_sizes=None, + inputs_without_pos=input_without_pos, + subsampled_points=subsampled_points.get(modality, None), + ) + decoder_queries[modality] = query + + # Pad all queries with trainable position encodings to make them have the same channels + + def embed(modality, x): + x = torch.reshape(x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]]) + pos = self.padding[modality] + pos = torch.broadcast_to(pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]]) + return torch.cat([x, pos], dim=2) + + # Apply a predictable ordering to the modalities + return torch.cat( + [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1 + ) + + def forward( + self, + query: torch.Tensor, + z: torch.FloatTensor, + query_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + # B x 1 x num_classes -> B x num_classes + decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) + + return decoder_outputs + + +# Below: IO pre- and post-processor classes for Perceiver. +def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor: + """ + Space to depth transform. Rearranges blocks of spatial data, into depth. + + This function assumes the channels to be first, but will place the channels last after transformation. + + Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15. 
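+
+    Example (shapes follow from the reshaping logic below):
+
+    ```python
+    import torch
+
+    image = torch.randn(1, 3, 224, 224)  # rank 4: (batch, channels, height, width)
+    space_to_depth(image, spatial_block_size=2).shape  # (1, 112, 112, 12), channels last
+
+    video = torch.randn(1, 16, 3, 224, 224)  # rank 5: (batch, time, channels, height, width)
+    space_to_depth(video, temporal_block_size=2, spatial_block_size=4).shape  # (1, 8, 56, 56, 96)
+    ```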
+ """ + if len(frames.shape) == 4: + batch_size, num_channels, height, width = frames.shape + # split up dimensions (height by spatial_block_size, width by spatial_block_size) + frames = frames.view( + batch_size, + num_channels, + height // spatial_block_size, + spatial_block_size, + width // spatial_block_size, + spatial_block_size, + ) + # move blocks to last dimension: (batch_size, H//bs, W//bs, bs, bs, C) + frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous() + # concatenate blocks along channel dimension: (batch_size, H//bs, W//bs, bs*bs*C) + frames = frames.view( + batch_size, + height // spatial_block_size, + width // spatial_block_size, + (spatial_block_size**2) * num_channels, + ) + return frames + elif len(frames.shape) == 5: + batch_size, time, num_channels, height, width = frames.shape + # split up dimensions (time by temporal_block_size, height by spatial_block_size, width by spatial_block_size) + frames = frames.view( + batch_size, + time // temporal_block_size, + temporal_block_size, + num_channels, + height // spatial_block_size, + spatial_block_size, + width // spatial_block_size, + spatial_block_size, + ) + # move blocks to last dimension: (batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C) + frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous() + # concatenate blocks along channel dimension: (batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C) + frames = frames.view( + batch_size, + time // temporal_block_size, + height // spatial_block_size, + width // spatial_block_size, + temporal_block_size * (spatial_block_size**2) * num_channels, + ) + return frames + else: + raise ValueError( + "Frames should be of rank 4 (batch, channels, height, width)" + " or rank 5 (batch, time, channels, height, width)" + ) + + +class Conv2dSamePadding(nn.Conv2d): + """ + Conv2d layer with padding="same" support. Source: + https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6 + """ + + def __init__(self, *args, **kwargs): + super(Conv2dSamePadding, self).__init__(*args, **kwargs) + self.zero_pad_2d = nn.ZeroPad2d( + reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]]) + ) + + def forward(self, input): + return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias) + + +class Conv2DDownsample(nn.Module): + """Downsamples 4x by applying a 2D convolution and doing max pooling.""" + + def __init__( + self, + num_layers: int = 1, + in_channels: int = 3, + out_channels: int = 64, + use_batchnorm: bool = True, + ): + """ + Constructs a Conv2DDownsample model. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 64): + The number of conv output channels. + use_batchnorm (`bool`, *optional*, defaults to `True`): + Whether to use batchnorm. + """ + super().__init__() + + self.conv = Conv2dSamePadding( + in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False + ) + self.batchnorm = nn.BatchNorm2d(num_features=out_channels) if use_batchnorm else nn.Identity() + self.relu = nn.ReLU() + self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + out = self.conv(inputs) + out = self.batchnorm(out) + out = self.relu(out) + out = self.max_pool(out) + return out + + +def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False): + """ + Generate a Fourier frequency position encoding with linear spacing. 
+ + Args: + pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`): + The Tensor containing the position of n points in d dimensional space. + num_bands (`int`): + The number of frequency bands (K) to use. + max_resolution (`Tuple[int]`, *optional*, defaults to (224, 224)): + The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension. + concat_pos (`bool`, *optional*, defaults to `True`): + Whether to concatenate the input position encoding to the Fourier features. + sine_only (`bool`, *optional*, defaults to `False`): + Whether to use a single phase (sin) or two (sin/cos) for each frequency band. + + Returns: + `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If + `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d, + sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1), + ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the + kth frequency band. + """ + + batch_size = pos.shape[0] + + min_freq = 1.0 + # Nyquist frequency at the target resolution: + freq_bands = torch.stack( + [torch.linspace(start=min_freq, end=res / 2, steps=num_bands) for res in max_resolution], dim=0 + ) + + # Get frequency bands for each spatial dimension. + # Output is size [n, d * num_bands] + per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :] + per_pos_features = torch.reshape(per_pos_features, [-1, np.prod(per_pos_features.shape[1:])]) + + if sine_only: + # Output is size [n, d * num_bands] + per_pos_features = torch.sin(np.pi * (per_pos_features)) + else: + # Output is size [n, 2 * d * num_bands] + per_pos_features = torch.cat( + [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1 + ) + # Concatenate the raw input positions. + if concat_pos: + # Adds d bands to the encoding. + per_pos_features = torch.cat([pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1) + return per_pos_features + + +def build_linear_positions(index_dims, output_range=(-1.0, 1.0)): + """ + Generate an array of position indices for an N-D input array. + + Args: + index_dims (`List[int]`): + The shape of the index dimensions of the input array. + output_range (`Tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`): + The min and max values taken by each input index dimension. + + Returns: + `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`. 
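+
+    Example (values follow from the linspace/meshgrid construction below):
+
+    ```python
+    pos = build_linear_positions((3, 2))
+    # pos.shape == (3, 2, 2); pos[..., 0] is the row coordinate, pos[..., 1] the column coordinate
+    # pos[0, 0] == tensor([-1., -1.]) and pos[2, 1] == tensor([1., 1.])
+    ```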
+    """
+
+    def _linspace(n_xels_per_dim):
+        return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)
+
+    dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
+    array_index_grid = meshgrid(*dim_ranges, indexing="ij")
+
+    return torch.stack(array_index_grid, dim=-1)
+
+
+class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta):
+    """Perceiver abstract position encoding."""
+
+    @property
+    @abc.abstractmethod
+    def num_dimensions(self) -> int:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def output_size(self, *args, **kwargs) -> int:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def forward(self, batch_size, pos):
+        raise NotImplementedError
+
+
+class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
+    """Trainable position encoding."""
+
+    def __init__(self, index_dims, num_channels=128):
+        super().__init__()
+        self._num_channels = num_channels
+        self._index_dims = index_dims
+        index_dim = np.prod(index_dims)
+        self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))
+
+    @property
+    def num_dimensions(self) -> int:
+        if isinstance(self._index_dims, int):
+            return 1
+        return len(self._index_dims)
+
+    def output_size(self, *args, **kwargs) -> int:
+        return self._num_channels
+
+    def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        num_positions = position_embeddings.shape[0]
+        new_height = new_width = torch_int(num_positions**0.5)
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and height == new_height and width == new_width:
+            return position_embeddings
+
+        position_embeddings = position_embeddings.reshape(1, new_height, new_width, self._num_channels).permute(
+            0, 3, 1, 2
+        )
+
+        # interpolate from the original (new_height, new_width) grid to the requested (height, width) grid
+        position_embeddings = nn.functional.interpolate(
+            position_embeddings,
+            size=(height, width),
+            mode="bicubic",
+            align_corners=False,
+        )
+        position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
+        return position_embeddings
+
+    def forward(
+        self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size = None
+    ) -> torch.Tensor:
+        position_embeddings = self.position_embeddings
+
+        if interpolate_pos_encoding:
+            height, width = input_size
+            position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)
+
+        if batch_size is not None:
+            position_embeddings = position_embeddings.expand(batch_size, -1, -1)
+        return position_embeddings
+
+
+def _check_or_build_spatial_positions(pos, index_dims, batch_size):
+    """
+    Checks or builds spatial position features (x, y, ...).
+
+    Args:
+        pos (`torch.FloatTensor`):
+            None, or an array of position features. If None, position features are built. Otherwise, their size is checked.
+        index_dims (`List[int]`):
+            An iterable giving the spatial/index size of the data to be featurized.
+        batch_size (`int`):
+            The batch size of the data to be featurized.
+
+    Returns:
+        `torch.FloatTensor` of shape `(batch_size, prod(index_dims), len(index_dims))`: an array of position features.
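+
+    Example (a sketch of the built positions when `pos` is None):
+
+    ```python
+    pos = _check_or_build_spatial_positions(None, index_dims=[4, 4], batch_size=2)
+    # pos.shape == (2, 16, 2): a flattened (row, column) grid in [-1, 1], broadcast over the batch
+    ```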
+ """ + if pos is None: + pos = build_linear_positions(index_dims) + # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)` + # but `torch.broadcast_to` cannot be converted to ONNX + pos = pos[None].expand((batch_size,) + pos.shape) + pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1]) + else: + # Just a warning label: you probably don't want your spatial features to + # have a different spatial layout than your pos coordinate system. + # But feel free to override if you think it'll work! + if pos.shape[-1] != len(index_dims): + raise ValueError("Spatial features have the wrong number of dimensions.") + return pos + + +class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding): + """Fourier (Sinusoidal) position encoding.""" + + def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False): + super().__init__() + self.num_bands = num_bands + self.max_resolution = max_resolution + self.concat_pos = concat_pos + self.sine_only = sine_only + + @property + def num_dimensions(self) -> int: + return len(self.max_resolution) + + def output_size(self): + """Returns size of positional encodings last dimension.""" + num_dims = len(self.max_resolution) + encoding_size = self.num_bands * num_dims + if not self.sine_only: + encoding_size *= 2 + if self.concat_pos: + encoding_size += self.num_dimensions + + return encoding_size + + def forward( + self, + index_dims: List[int], + batch_size: int, + device: torch.device, + dtype: torch.dtype, + pos: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + pos = _check_or_build_spatial_positions(pos, index_dims, batch_size) + fourier_pos_enc = generate_fourier_features( + pos, + num_bands=self.num_bands, + max_resolution=self.max_resolution, + concat_pos=self.concat_pos, + sine_only=self.sine_only, + ).to(device=device, dtype=dtype) + return fourier_pos_enc + + +class AbstractPreprocessor(nn.Module): + @property + def num_channels(self) -> int: + """Returns size of preprocessor output.""" + raise NotImplementedError() + + +class PerceiverTextPreprocessor(AbstractPreprocessor): + """ + Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings. + + The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration. + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config: PerceiverConfig) -> None: + super().__init__() + self.config = config + self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) + + @property + def num_channels(self) -> int: + return self.config.d_model + + def forward( + self, + inputs: torch.LongTensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): + embeddings_without_pos = self.embeddings(inputs) + + seq_length = inputs.shape[1] + position_ids = torch.arange(0, seq_length, device=inputs.device) + embeddings = embeddings_without_pos + self.position_embeddings(position_ids) + + return embeddings, None, embeddings_without_pos + + +class PerceiverEmbeddingDecoder(nn.Module): + """ + Module to decode embeddings (for masked language modeling). + + Args: + config ([`PerceiverConfig`]): + Model configuration. 
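+
+    The decoder scores hidden states against a (typically tied) embedding matrix; a minimal sketch with hypothetical
+    sizes:
+
+    ```python
+    import torch
+    from torch import nn
+
+    embedding_layer = nn.Embedding(262, 768)  # e.g. a byte-level vocabulary with d_model=768
+    hidden_states = torch.randn(2, 2048, 768)
+    logits = hidden_states @ embedding_layer.weight.T  # (2, 2048, 262); the module also adds a learned bias
+    ```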
+ """ + + def __init__(self, config: PerceiverConfig) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.bias = nn.Parameter(torch.zeros(self.vocab_size)) + + def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, d_model = hidden_states.shape + # Flatten batch dim + output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1)) + output = output + self.bias + + return output.reshape([batch_size, seq_len, self.vocab_size]) + + +class PerceiverMultimodalPostprocessor(nn.Module): + """ + Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single + postprocessor. + + Args: + modalities (`Mapping[str, PostprocessorType]`): + Dictionary mapping modality name to postprocessor class for that modality. + input_is_dict (`bool`, *optional*, defaults to `False`): + If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If + False, input is a tensor which is sliced up during postprocessing by *modality_sizes*. + """ + + def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False): + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.input_is_dict = input_is_dict + + def forward( + self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None + ) -> Mapping[str, torch.Tensor]: + if not self.input_is_dict: + # Slice up modalities by their sizes. + if modality_sizes is None: + raise ValueError("Modality sizes should be specified if input is not a dictionary.") + inputs = restructure(modality_sizes=modality_sizes, inputs=inputs) + + outputs = { + modality: postprocessor(inputs[modality], pos=pos, modality_sizes=None) + for modality, postprocessor in self.modalities.items() + } + return outputs + + +class PerceiverClassificationPostprocessor(nn.Module): + """ + Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + in_channels (`int`): + Number of channels in the input. + """ + + def __init__(self, config: PerceiverConfig, in_channels: int) -> None: + super().__init__() + self.classifier = nn.Linear(in_channels, config.num_labels) + + def forward(self, inputs, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return logits[:, 0, :] + + +class PerceiverAudioPostprocessor(nn.Module): + """ + Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + in_channels (`int`): + Number of channels in the input. + postproc_type (`str`, *optional*, defaults to `"patches"`): + Postprocessor type to use. Currently, only "patches" is supported. 
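+
+    A sketch of the resulting shapes (with a hypothetical `samples_per_patch` of 16):
+
+    ```python
+    # decoder output: (batch_size, num_patches, in_channels), e.g. (1, 240, 512)
+    # the linear head maps in_channels -> samples_per_patch: (1, 240, 16)
+    # flattening the patches yields the raw waveform: (1, 3840)
+    ```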
+ """ + + def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None: + super().__init__() + + if postproc_type not in ("patches",): # to be supported: 'conv', 'patches', 'pixels' + raise ValueError("Invalid postproc_type!") + + # Architecture parameters: + self.classifier = nn.Linear(in_channels, config.samples_per_patch) + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return torch.reshape(logits, [inputs.shape[0], -1]) + + +class PerceiverProjectionPostprocessor(nn.Module): + """ + Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower + dimension. + + Args: + in_channels (`int`): + Number of channels in the input. + out_channels (`int`): + Number of channels in the output. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__() + self.classifier = nn.Linear(in_channels, out_channels) + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return logits + + +class PerceiverImagePreprocessor(AbstractPreprocessor): + """ + Image preprocessing for Perceiver Encoder. + + Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to + "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the + position encoding kwargs are set equal to the *out_channels*. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + prep_type (`str`, *optional*, defaults to `"conv"`): + Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels". + spatial_downsample (`int`, *optional*, defaults to 4): + Spatial downsampling factor. + temporal_downsample (`int`, *optional*, defaults to 1): + Temporal downsampling factor (only relevant in case a time dimension is present). + position_encoding_type (`str`, *optional*, defaults to `"fourier"`): + Position encoding type. Can be "fourier" or "trainable". + in_channels (`int`, *optional*, defaults to 3): + Number of channels in the input. + out_channels (`int`, *optional*, defaults to 64): + Number of channels in the output. + conv_after_patching (`bool`, *optional*, defaults to `False`): + Whether to apply a convolutional layer after patching. + conv_after_patching_in_channels (`int`, *optional*, defaults to 54): + Number of channels in the input of the convolutional layer after patching. + conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`): + Whether to use batch normalization in the convolutional layer. + concat_or_add_pos (`str`, *optional*, defaults to `"concat"`): + How to concatenate the position encoding to the input. Can be "concat" or "add". + project_pos_dim (`int`, *optional*, defaults to -1): + Dimension of the position encoding to project to. If -1, no projection is applied. + **position_encoding_kwargs (`Dict`, *optional*): + Keyword arguments for the position encoding. 
+ """ + + def __init__( + self, + config, + prep_type="conv", + spatial_downsample: int = 4, + temporal_downsample: int = 1, + position_encoding_type: str = "fourier", + in_channels: int = 3, + out_channels: int = 64, + conv_after_patching: bool = False, + conv_after_patching_in_channels: int = 54, # only relevant when conv_after_patching = True + conv2d_use_batchnorm: bool = True, + concat_or_add_pos: str = "concat", + project_pos_dim: int = -1, + **position_encoding_kwargs, + ): + super().__init__() + self.config = config + + if prep_type not in ("conv", "patches", "pixels", "conv1x1"): + raise ValueError(f"Prep_type {prep_type} is invalid") + + if concat_or_add_pos not in ["concat", "add"]: + raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.") + + self.in_channels = in_channels + self.prep_type = prep_type + self.spatial_downsample = spatial_downsample + self.temporal_downsample = temporal_downsample + self.position_encoding_type = position_encoding_type + self.concat_or_add_pos = concat_or_add_pos + self.conv_after_patching = conv_after_patching + self.out_channels = out_channels + + if self.prep_type == "conv": + # Downsampling with conv is currently restricted + convnet_num_layers = math.log(spatial_downsample, 4) + convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers) + if not convnet_num_layers_is_int or temporal_downsample != 1: + raise ValueError( + "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv." + ) + self.convnet = Conv2DDownsample( + in_channels=in_channels, + num_layers=int(convnet_num_layers), + out_channels=out_channels, + use_batchnorm=conv2d_use_batchnorm, + ) + + elif self.prep_type == "conv1x1": + if temporal_downsample != 1: + raise ValueError("Conv1x1 does not downsample in time.") + self.convnet_1x1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + # spatial_downsample is unconstrained for 1x1 convolutions. + stride=(spatial_downsample, spatial_downsample), + ) + + # Position embeddings + self.project_pos_dim = project_pos_dim + self.position_embeddings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, + out_channels=out_channels, + project_pos_dim=project_pos_dim, + **position_encoding_kwargs, + ) + + # Optional convolutional layer after patches. + self.conv_after_patches = ( + nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity() + ) + + @property + def num_channels(self) -> int: + # Let's assume that the number of resolutions (in the context of image preprocessing) + # of the input data is 2 or 3 depending on whether we are processing image or video respectively. + # In this case, for convenience, we will declare is_temporal variable, + # which will show whether the data has a temporal dimension or not. 
+ is_temporal = self.position_embeddings.num_dimensions > 2 + + # position embedding + if self.project_pos_dim > 0: + pos_dim = self.project_pos_dim + else: + pos_dim = self.position_embeddings.output_size() + if self.concat_or_add_pos == "add": + return pos_dim + + # inputs + if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"): + inp_dim = self.out_channels + elif self.prep_type == "pixels": + inp_dim = self.in_channels + if not is_temporal: + inp_dim = math.ceil(inp_dim / self.spatial_downsample) + elif self.prep_type == "patches": + if self.conv_after_patching: + inp_dim = self.out_channels + else: + inp_dim = self.in_channels * self.spatial_downsample**2 + if is_temporal: + inp_dim *= self.temporal_downsample + + return inp_dim + pos_dim + + def _build_network_inputs( + self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False + ): + """ + Construct the final input, including position encoding. + + This method expects the inputs to always have channels as last dimension. + + """ + batch_size = inputs.shape[0] + input_size = inputs.shape[1:3] + index_dims = inputs.shape[1:-1] + indices = np.prod(index_dims) + + # Flatten input features to a 1D index dimension if necessary. + if len(inputs.shape) > 3 and network_input_is_1d: + inputs = torch.reshape(inputs, [batch_size, indices, -1]) + + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size) + elif self.position_encoding_type == "fourier": + pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) + + # Optionally project them to a target dimension. + pos_enc = self.positions_projection(pos_enc) + + if not network_input_is_1d: + # Reshape pos to match the input feature shape + # if the network takes non-1D inputs + sh = inputs.shape + pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1]) + if self.concat_or_add_pos == "concat": + inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1) + elif self.concat_or_add_pos == "add": + inputs_with_pos = inputs + pos_enc + return inputs_with_pos, inputs + + def forward( + self, + inputs: torch.Tensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): + if self.prep_type == "conv": + # Convnet image featurization. + # Downsamples spatially by a factor of 4 + inputs = self.convnet(inputs) + + elif self.prep_type == "conv1x1": + # map inputs to self.out_channels + inputs = self.convnet_1x1(inputs) + + elif self.prep_type == "pixels": + # if requested, downsamples in the crudest way + if inputs.ndim == 4: + inputs = inputs[:: self.spatial_downsample, :: self.spatial_downsample] + elif inputs.ndim == 5: + inputs = inputs[ + :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample + ] + else: + raise ValueError("Unsupported data format for pixels.") + + elif self.prep_type == "patches": + # Space2depth featurization. + # Video: B x T x C x H x W + inputs = space_to_depth( + inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample + ) + + if inputs.ndim == 5 and inputs.shape[1] == 1: + # for flow + inputs = inputs.squeeze(dim=1) + + # Optionally apply conv layer. 
+ inputs = self.conv_after_patches(inputs)
+
+ if self.prep_type != "patches":
+ # move channels to last dimension, as the _build_network_inputs method below expects this
+ if inputs.ndim == 4:
+ inputs = inputs.permute(0, 2, 3, 1)
+ elif inputs.ndim == 5:
+ inputs = inputs.permute(0, 1, 3, 4, 2)
+ else:
+ raise ValueError("Unsupported data format: expected 4D (image) or 5D (video) inputs.")
+
+ inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
+ modality_sizes = None # Size for each modality, only needed for multimodal
+
+ return inputs, modality_sizes, inputs_without_pos
+
+
+class PerceiverOneHotPreprocessor(AbstractPreprocessor):
+ """
+ One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input.
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config: PerceiverConfig) -> None:
+ super().__init__()
+ self.config: PerceiverConfig = config
+
+ @property
+ def num_channels(self) -> int:
+ return self.config.num_labels
+
+ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+ # Add a dummy index dimension.
+ inputs = inputs[:, None, :]
+
+ # No position encodings, so the 1st (input) and 3rd (inputs_without_pos)
+ # outputs are identical.
+ return inputs, None, inputs
+
+
+class PerceiverAudioPreprocessor(AbstractPreprocessor):
+ """
+ Audio preprocessing for Perceiver Encoder.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ prep_type (`str`, *optional*, defaults to `"patches"`):
+ Preprocessor type to use. Only "patches" is supported.
+ samples_per_patch (`int`, *optional*, defaults to 96):
+ Number of samples per patch.
+ position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
+ Type of position encoding to use. Can be "trainable" or "fourier".
+ concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
+ How to concatenate the position encoding to the input. Can be "concat" or "add".
+ out_channels (`int`, *optional*, defaults to 64):
+ Number of channels in the output.
+ project_pos_dim (`int`, *optional*, defaults to -1):
+ Dimension of the position encoding to project to. If -1, no projection is applied.
+ **position_encoding_kwargs (`Dict`, *optional*):
+ Keyword arguments for the position encoding.
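+
+ Example (illustrative sketch along the lines of a multimodal autoencoding setup; `n_audio_samples` and
+ the Fourier keyword values are assumptions for demonstration and are forwarded to
+ `build_position_encoding`):
+
+ ```python
+ >>> from transformers import PerceiverConfig
+ >>> from transformers.models.perceiver.modeling_perceiver import PerceiverAudioPreprocessor
+
+ >>> config = PerceiverConfig()
+ >>> n_audio_samples = config.num_frames * config.audio_samples_per_frame
+ >>> preprocessor = PerceiverAudioPreprocessor(
+ ... config,
+ ... prep_type="patches",
+ ... samples_per_patch=config.samples_per_patch,
+ ... position_encoding_type="fourier",
+ ... fourier_position_encoding_kwargs=dict(
+ ... num_bands=192,
+ ... max_resolution=(n_audio_samples,),
+ ... sine_only=False,
+ ... concat_pos=True,
+ ... ),
+ ... )
+ ```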
+ """ + + def __init__( + self, + config, + prep_type: str = "patches", + samples_per_patch: int = 96, + position_encoding_type: str = "fourier", + concat_or_add_pos: str = "concat", + out_channels=64, + project_pos_dim=-1, + **position_encoding_kwargs, + ): + super().__init__() + self.config = config + + if prep_type not in ("patches",): + raise ValueError(f"Prep_type {prep_type} is invalid, can only be 'patches'.") + + if concat_or_add_pos not in ["concat", "add"]: + raise ValueError(f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.") + + self.samples_per_patch = samples_per_patch + self.position_encoding_type = position_encoding_type + self.concat_or_add_pos = concat_or_add_pos + self.project_pos_dim = project_pos_dim + + # Position embeddings + self.position_embeddings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, + out_channels=out_channels, + project_pos_dim=project_pos_dim, + **position_encoding_kwargs, + ) + + @property + def num_channels(self) -> int: + # position embedding + if self.project_pos_dim > 0: + pos_dim = self.project_pos_dim + else: + pos_dim = self.position_embeddings.output_size() + if self.concat_or_add_pos == "add": + return pos_dim + return self.samples_per_patch + pos_dim + + def _build_network_inputs(self, inputs): + """Construct the final input, including position encoding.""" + batch_size = inputs.shape[0] + index_dims = inputs.shape[1:-1] + + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_enc = self.position_embeddings(batch_size) + elif self.position_encoding_type == "fourier": + pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) + + # Optionally project them to a target dimension. + pos_enc = self.positions_projection(pos_enc) + + if self.concat_or_add_pos == "concat": + inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1) + elif self.concat_or_add_pos == "add": + inputs_with_pos = inputs + pos_enc + + return inputs_with_pos, inputs + + def forward( + self, + inputs: torch.Tensor, + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ): + inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch]) + + inputs, inputs_without_pos = self._build_network_inputs(inputs) + modality_sizes = None # Size for each modality, only needed for multimodal + + return inputs, modality_sizes, inputs_without_pos + + +class PerceiverMultimodalPreprocessor(AbstractPreprocessor): + """ + Multimodal preprocessing for Perceiver Encoder. + + Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number + of channels. + + Args: + modalities (`Mapping[str, PreprocessorType]`): + Dict mapping modality name to preprocessor. + mask_probs (`Dict[str, float]`): + Dict mapping modality name to masking probability of that modality. + min_padding_size (`int`, *optional*, defaults to 2): + The minimum padding size for all modalities. The final output will have num_channels equal to the maximum + channels across all modalities plus min_padding_size. 
+ """ + + def __init__( + self, + modalities: Mapping[str, PreprocessorType], + mask_probs: Optional[Mapping[str, float]] = None, + min_padding_size: int = 2, + ): + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.min_padding_size = min_padding_size + self.mask_probs = mask_probs if mask_probs is not None else {} + self.padding = nn.ParameterDict( + { + modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels)) + for modality, preprocessor in modalities.items() + } + ) + self.mask = nn.ParameterDict( + {modality: nn.Parameter(torch.randn(1, self.num_channels)) for modality, _ in self.mask_probs.items()} + ) + + @property + def num_channels(self) -> int: + max_channel_size = max(processor.num_channels for _, processor in self.modalities.items()) + common_channel_size = max_channel_size + self.min_padding_size + return common_channel_size + + def forward( + self, + inputs: Mapping[str, torch.Tensor], + pos: Optional[torch.Tensor] = None, + network_input_is_1d: bool = True, + interpolate_pos_encoding: bool = False, + ) -> PreprocessorOutputType: + padded = {} + modality_sizes = {} + inputs_without_pos = {} + for modality, preprocessor in self.modalities.items(): + # preprocess each modality using the respective preprocessor. + output, _, inputs_without_pos[modality] = preprocessor( + inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d + ) + + # pad to the same common_channel_size. + batch_size, num_samples, num_channels = output.shape + pos_enc = self.padding[modality].expand(batch_size, -1, -1) + + padding = torch.broadcast_to( + pos_enc, + [batch_size, num_samples, self.num_channels - num_channels], + ) + output_padded = torch.cat([output, padding], dim=2) + + # mask if required + if modality in self.mask_probs: + mask_token = self.mask[modality].expand(batch_size, -1, -1) + mask_prob = self.mask_probs[modality] + mask = torch.bernoulli(torch.full([batch_size, num_samples], mask_prob)) + mask = torch.unsqueeze(mask, dim=2).to(mask_token.device) + output_padded = (1 - mask) * output_padded + mask * mask_token + + padded[modality] = output_padded + modality_sizes[modality] = output_padded.shape[1] + + # Apply a predictable ordering to the modalities + padded_ls = [padded[k] for k in sorted(padded.keys())] + + # Finally, concatenate along the time dimension + final_inputs = torch.cat(padded_ls, dim=1) + + return final_inputs, modality_sizes, inputs_without_pos + + +__all__ = [ + "PerceiverForImageClassificationConvProcessing", + "PerceiverForImageClassificationFourier", + "PerceiverForImageClassificationLearned", + "PerceiverForMaskedLM", + "PerceiverForMultimodalAutoencoding", + "PerceiverForOpticalFlow", + "PerceiverForSequenceClassification", + "PerceiverLayer", + "PerceiverModel", + "PerceiverPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/perceiver/tokenization_perceiver.py b/docs/transformers/src/transformers/models/perceiver/tokenization_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7fe6f43c6d7a0b319f564c7a42c83ad7966afb --- /dev/null +++ b/docs/transformers/src/transformers/models/perceiver/tokenization_perceiver.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Perceiver.""" + +from typing import Dict, List, Optional, Tuple + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PerceiverTokenizer(PreTrainedTokenizer): + """ + Construct a Perceiver tokenizer. The Perceiver simply uses raw bytes utf-8 encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + bos_token (`str`, *optional*, defaults to `"[BOS]"`): + The BOS token (reserved in the vocab, but not actually used). + eos_token (`str`, *optional*, defaults to `"[EOS]"`): + The end of sequence token (reserved in the vocab, but not actually used). + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The MASK token, useful for masked language modeling. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The CLS token (reserved in the vocab, but not actually used). + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from two sequences. 
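+
+ Example (illustrative; the ids follow directly from the implementation below — bytes are shifted by the
+ 6 reserved special tokens, and `[CLS]`/`[SEP]` have ids 4 and 5):
+
+ ```python
+ >>> from transformers import PerceiverTokenizer
+
+ >>> tokenizer = PerceiverTokenizer()
+ >>> tokenizer("hi").input_ids # [CLS], ord("h") + 6, ord("i") + 6, [SEP]
+ [4, 110, 111, 5]
+ ```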
+
+ """
+
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ pad_token="[PAD]",
+ bos_token="[BOS]",
+ eos_token="[EOS]",
+ mask_token="[MASK]",
+ cls_token="[CLS]",
+ sep_token="[SEP]",
+ model_max_length=2048,
+ **kwargs,
+ ) -> None:
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ mask_token = AddedToken(mask_token, lstrip=False, rstrip=False) if isinstance(mask_token, str) else mask_token
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+
+ self._utf_vocab_size = 2**8 # UTF-8 code units are 8 bits, i.e. 256 byte values
+
+ # Since these tokens are not part of the vocabulary, we manually add them
+ self._added_tokens_decoder: Dict[int, AddedToken] = {
+ 0: pad_token,
+ 1: bos_token,
+ 2: eos_token,
+ 3: mask_token,
+ 4: cls_token,
+ 5: sep_token,
+ }
+ self._num_special_tokens = len(self._added_tokens_decoder)
+ super().__init__(
+ pad_token=pad_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ mask_token=mask_token,
+ cls_token=cls_token,
+ sep_token=sep_token,
+ model_max_length=model_max_length,
+ **kwargs,
+ )
+
+ def get_vocab(self) -> Dict[str, int]:
+ vocab = {}
+ for i in range(self._utf_vocab_size):
+ token = chr(i)
+ vocab[token] = i + self._num_special_tokens
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ @property
+ def vocab_size(self):
+ return self._utf_vocab_size
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ # normal case: some special tokens
+ if token_ids_1 is None:
+ return [1] + [0] * len(token_ids_0) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks. A sequence has the
+ following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
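+
+ Example (illustrative; the byte ids are already offset by the 6 reserved special tokens):
+
+ ```python
+ >>> from transformers import PerceiverTokenizer
+
+ >>> tokenizer = PerceiverTokenizer()
+ >>> tokenizer.build_inputs_with_special_tokens([110, 111]) # [CLS] "hi" [SEP]
+ [4, 110, 111, 5]
+ ```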
+ """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + else: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id] + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + tokens = [chr(i) for i in text.encode("utf-8")] + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if len(token) != 1: + token_id = self.unk_token_id + else: + token_id = ord(token) + self._num_special_tokens + return token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = chr(index - self._num_special_tokens) + return token + + # TODO @ArthurZ refactor this as well.... + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + bstring = b"" + for token in tokens: + if token in self.added_tokens_encoder: + tok_string = str(token).encode("utf-8") + else: + tok_string = bytes([ord(token)]) + bstring += tok_string + string = bstring.decode("utf-8", errors="replace") + return string + + # PerceiverTokenizer has no vocab file + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + return () + + +__all__ = ["PerceiverTokenizer"] diff --git a/docs/transformers/src/transformers/models/persimmon/__init__.py b/docs/transformers/src/transformers/models/persimmon/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb71eae2547c59a4f9ba7ecbafda56fb9c86b494 --- /dev/null +++ b/docs/transformers/src/transformers/models/persimmon/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_persimmon import * + from .modeling_persimmon import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/persimmon/configuration_persimmon.py b/docs/transformers/src/transformers/models/persimmon/configuration_persimmon.py new file mode 100644 index 0000000000000000000000000000000000000000..80ca823f28e1fe7419dbc5c5ea6298c02e882e7d --- /dev/null +++ b/docs/transformers/src/transformers/models/persimmon/configuration_persimmon.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2023 Adept AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Persimmon model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PersimmonConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PersimmonModel`]. It is used to instantiate a
+ Persimmon model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the
+ [adept/persimmon-8b-base](https://huggingface.co/adept/persimmon-8b-base).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the Persimmon model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`PersimmonModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 16384):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 36):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 64):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 16384):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 25000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2.
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to the low frequency components of the RoPE.
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to the high frequency components of the RoPE.
+ qk_layernorm (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the queries and keys after projecting the hidden states.
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after applying the MLP to the hidden states.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after computing the attention scores.
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+ Percentage of the query and keys which will have rotary embedding.
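+
+ A minimal RoPE-scaling sketch (illustrative only; the dictionary is checked by `rope_config_validation`
+ at init time, and linear scaling with `factor=2.0` is assumed purely for demonstration):
+
+ ```python
+ >>> from transformers import PersimmonConfig
+
+ >>> config = PersimmonConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
+ ```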
+ + Example: + + ```python + >>> from transformers import PersimmonModel, PersimmonConfig + + >>> # Initializing a Persimmon persimmon-7b style configuration + >>> configuration = PersimmonConfig() + ```""" + + model_type = "persimmon" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=262144, + hidden_size=4096, + intermediate_size=16384, + num_hidden_layers=36, + num_attention_heads=64, + hidden_act="relu2", + max_position_embeddings=16384, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=25000.0, + rope_scaling=None, + qk_layernorm=True, + hidden_dropout=0.0, + attention_dropout=0.0, + partial_rotary_factor=0.5, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.qk_layernorm = qk_layernorm + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.partial_rotary_factor = partial_rotary_factor + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["PersimmonConfig"] diff --git a/docs/transformers/src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py b/docs/transformers/src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b410fd3bbf8104ff62e7de381f0b9398d43c01 --- /dev/null +++ b/docs/transformers/src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py @@ -0,0 +1,129 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +import warnings + +import flatdict +import torch + +from transformers import LlamaTokenizer, PersimmonConfig, PersimmonForCausalLM + + +try: + from transformers import LlamaTokenizerFast + + tokenizer_class = LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. 
To use the fast tokenizer, update your `tokenizers` library and re-run the tokenizer conversion"
+ )
+ tokenizer_class = LlamaTokenizer
+
+"""
+Sample usage:
+
+```
+git clone https://github.com/persimmon-ai-labs/adept-inference
+wget https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_base_model_release.tar
+wget https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_chat_model_release.tar
+python src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py --input_dir /path/to/downloaded/persimmon/weights/ --output_dir /output/path
+```
+
+Thereafter, models can be loaded via:
+
+```py
+from transformers import PersimmonForCausalLM, PersimmonTokenizer
+
+model = PersimmonForCausalLM.from_pretrained("/output/path")
+tokenizer = PersimmonTokenizer.from_pretrained("/output/path")
+```
+
+Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
+come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them need to be loaded in RAM).
+"""
+
+
+KEYS_TO_MODIFY_MAPPING = {
+ "self_attention": "self_attn",
+ "language_model.encoder": "model",
+ "word_embeddings_for_head": "lm_head",
+ "language_model.embedding.word_embeddings": "model.embed_tokens",
+}
+
+KEYS_TO_REMOVE = "rotary_emb.inv_freq"
+
+
+def rename_state_dict(state_dict):
+ model_state_dict = {}
+ for key, value in state_dict.items():
+ for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+ if key_to_modify in key:
+ key = key.replace(key_to_modify, new_key)
+ if KEYS_TO_REMOVE in key:
+ continue
+ model_state_dict[key] = value
+ return model_state_dict
+
+
+def convert_persimmon_checkpoint(pytorch_dump_folder_path, ada_lib_path, pt_model_path, safe_serialization=False):
+ import sys
+
+ sys.path.insert(0, ada_lib_path)
+ model_state_dict_base = torch.load(pt_model_path, map_location="cpu", weights_only=True)
+ state_dict = flatdict.FlatDict(model_state_dict_base["model"], ".")
+ state_dict = rename_state_dict(state_dict)
+
+ transformers_config = PersimmonConfig()
+ model = PersimmonForCausalLM(transformers_config, eos_token_id=71013, bos_token_id=71013).to(torch.bfloat16)
+ model.load_state_dict(state_dict)
+ model.save_pretrained(pytorch_dump_folder_path, safe_serialization=safe_serialization)
+ transformers_config.save_pretrained(pytorch_dump_folder_path)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input_dir",
+ help="Location of Persimmon weights, which contains tokenizer.model and model folders",
+ )
+ parser.add_argument(
+ "--pt_model_path",
+ help="Location of Persimmon `model_optim_rng.pt`",
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Location to write HF model and tokenizer",
+ )
+ parser.add_argument(
+ "--ada_lib_path",
+ help="Location of the adept-inference code, added to `sys.path` so the checkpoint can be deserialized",
+ )
+ parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+ args = parser.parse_args()
+ spm_path = os.path.join(args.input_dir, "adept_vocab.model")
+
+ convert_persimmon_checkpoint(
+ pytorch_dump_folder_path=args.output_dir,
+ pt_model_path=args.pt_model_path,
+ safe_serialization=args.safe_serialization,
+ ada_lib_path=args.ada_lib_path,
+ )
+ tokenizer = tokenizer_class(spm_path, bos_token="|ENDOFTEXT|", eos_token="|ENDOFTEXT|")
+ tokenizer.save_pretrained(args.output_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git
a/docs/transformers/src/transformers/models/persimmon/modeling_persimmon.py b/docs/transformers/src/transformers/models/persimmon/modeling_persimmon.py new file mode 100644 index 0000000000000000000000000000000000000000..564865ddef6d664b3455a8811f448280ecebc9ac --- /dev/null +++ b/docs/transformers/src/transformers/models/persimmon/modeling_persimmon.py @@ -0,0 +1,1100 @@ +# coding=utf-8 +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Persimmon model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_persimmon import PersimmonConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base" +_CONFIG_FOR_DOC = "PersimmonConfig" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon +class PersimmonRotaryEmbedding(nn.Module): + def __init__(self, config: PersimmonConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon +class PersimmonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class PersimmonAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.rope_theta = config.rope_theta
+ self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+ self.qk_layernorm = config.qk_layernorm
+
+ if self.qk_layernorm:
+ self.q_layernorm = nn.LayerNorm(
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
+ )
+ self.k_layernorm = nn.LayerNorm(
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
+ )
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+ self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
+
+ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Split the last dimension into (num_heads, head_dim) without making any copies; the results share the same
+ memory storage as `fused_qkv`.
+
+ Args:
+ fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+ Returns:
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+ value: [batch_size, seq_length, num_heads, head_dim]
+ """
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ # [batch_size, seq_length, 3 x hidden_size]
+ fused_qkv = self.query_key_value(hidden_states)
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_states, key_states, value_states) = self._split_heads(fused_qkv)
+
+ if self.qk_layernorm:
+ query_states = self.q_layernorm(query_states)
+ key_states = self.k_layernorm(key_states)
+
+ # [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+ query_states = query_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+
+ # Partial rotary embedding: only the first `rotary_ndims = int(head_dim * config.partial_rotary_factor)`
+ # dimensions of each head are rotated; the remaining dimensions pass through unchanged.
+ query_rot, query_pass = (
+ query_states[..., : self.rotary_ndims],
+ query_states[..., self.rotary_ndims :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : self.rotary_ndims],
+ key_states[..., self.rotary_ndims :],
+ )
+ # [batch_size, num_heads, seq_length, rotary_ndims]
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, num_heads, seq_length, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+
key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) + attn_weights = self.attention_dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class PersimmonDecoderLayer(nn.Module): + def __init__(self, config: PersimmonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx) + self.mlp = PersimmonMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PERSIMMON_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PersimmonConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Persimmon Model outputting raw hidden-states without any specific head on top.", + PERSIMMON_START_DOCSTRING, +) +class PersimmonPreTrainedModel(PreTrainedModel): + config_class = PersimmonConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PersimmonDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +PERSIMMON_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance, see our
+              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
+              cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
+    PERSIMMON_START_DOCSTRING,
+)
+class PersimmonModel(PersimmonPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`].
+
+    Args:
+        config: PersimmonConfig
+    """
+
+    def __init__(self, config: PersimmonConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.rotary_emb = PersimmonRotaryEmbedding(config=config)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> BaseModelOutputWithPast:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
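+        # The logic below: (1) read how many tokens are already in the cache, (2) return `None` early when SDPA can
+        # rely on its `is_causal` flag instead of an explicit mask, and (3) otherwise materialize a full 4D float
+        # mask, sized to the static cache's capacity or to the current dynamic key length.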
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
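+            # "Inverted form" means additive: positions that may be attended hold 0.0 and masked positions hold a
+            # large negative value (typically `torch.finfo(dtype).min`), so the mask can be added to the attention
+            # scores as-is.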
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon + def __init__(self, config): + super().__init__(config) + self.model = PersimmonModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for
+            that token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PersimmonForCausalLM
+
+        >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")
+        >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
+
+        >>> prompt = "human: Hey, what should I eat for dinner?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # No upscaling to float was ever done for Persimmon
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The Persimmon transformer with a sequence classification head on top (linear layer).
+
+    [`PersimmonForSequenceClassification`] uses the last token in order to do the classification, as other causal
+    models (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
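+
+    A minimal usage sketch (illustrative only: "adept/persimmon-8b-base" ships no classification head, so the
+    `score` weights are freshly initialized and the logits are meaningless until the model is fine-tuned):
+
+    ```python
+    >>> import torch
+    >>> from transformers import AutoTokenizer, PersimmonForSequenceClassification
+
+    >>> model = PersimmonForSequenceClassification.from_pretrained("adept/persimmon-8b-base", num_labels=2)
+    >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
+    >>> inputs = tokenizer("Persimmons are surprisingly sweet.", return_tensors="pt")
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits  # one row of pooled logits per sequence
+    >>> logits.shape
+    torch.Size([1, 2])
+    ```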
+ """, + PERSIMMON_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PERSIMMON,Llama->Persimmon +class PersimmonForSequenceClassification(PersimmonPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PersimmonModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
+            )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    PERSIMMON_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
+class PersimmonForTokenClassification(PersimmonPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = PersimmonModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TokenClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over all tokens.
+ """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "PersimmonForCausalLM", + "PersimmonModel", + "PersimmonPreTrainedModel", + "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/phi/__init__.py b/docs/transformers/src/transformers/models/phi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cffe33da73ee42eb20a2e7f33ea9351bb7da75c2 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 Microsoft and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_phi import * + from .modeling_phi import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/phi/configuration_phi.py b/docs/transformers/src/transformers/models/phi/configuration_phi.py new file mode 100644 index 0000000000000000000000000000000000000000..4ffd89db3adc2caa5d918e2044b2385b783cea44 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi/configuration_phi.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Phi model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PhiConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate a Phi
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Phi
+    [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 51200):
+            Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`PhiModel`].
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            Dropout probability for MLP outputs.
+        embd_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after computing the attention scores.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 support up to 2048
+            tokens.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings or not.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
+            type and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this
+            value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            Percentage of the query and keys which will have rotary embedding.
+        qk_layernorm (`bool`, *optional*, defaults to `False`):
+            Whether or not to normalize the Queries and Keys after projecting the hidden states.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Denotes the beginning-of-sequence token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            Denotes the end-of-sequence token id.
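+
+    An illustrative `rope_scaling` configuration (a sketch only: the 4x linear factor below is an arbitrary example
+    value, not a tuned or recommended setting for Phi):
+
+    ```python
+    >>> from transformers import PhiConfig
+
+    >>> # Hypothetical: stretch the 2048-token pretraining window to 8192 tokens with linear RoPE scaling
+    >>> configuration = PhiConfig(
+    ...     max_position_embeddings=8192,
+    ...     rope_scaling={"rope_type": "linear", "factor": 4.0},
+    ... )
+    ```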
+ + Example: + + ```python + >>> from transformers import PhiModel, PhiConfig + + >>> # Initializing a Phi-1 style configuration + >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1") + + >>> # Initializing a model from the configuration + >>> model = PhiModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.dense": "rowwise", + "layers.*.mlp.fc1": "colwise", + "layers.*.mlp.fc2": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "embed_dropout": (["inputs_embeds"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "final_layernorm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=51200, + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=24, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="gelu_new", + max_position_embeddings=2048, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.5, + qk_layernorm=False, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.qk_layernorm = qk_layernorm + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["PhiConfig"] diff --git a/docs/transformers/src/transformers/models/phi/convert_phi_weights_to_hf.py b/docs/transformers/src/transformers/models/phi/convert_phi_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..097423366115f22ffb0d05bba859840d91dc38ad --- /dev/null +++ b/docs/transformers/src/transformers/models/phi/convert_phi_weights_to_hf.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Weights conversion script for Phi
+
+This script downloads the Phi-1, Phi-1.5 and Phi-2 checkpoints to "checkpoint_path" and then converts the weights to
+the HuggingFace model format and saves them in "pytorch_dump_folder_path".
+
+Example: $ python ./convert_phi_weights_to_hf.py --model_name "microsoft/phi-2" --pytorch_dump_folder ./dump_folder/ --checkpoint_path ./ckpt_path/
+"""
+
+import argparse
+import gc
+import os
+
+import safetensors
+import torch
+from huggingface_hub import hf_hub_download
+
+from transformers import PhiConfig, PhiForCausalLM
+
+
+_MODELS = {
+    "microsoft/phi-1": ["https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin"],
+    "microsoft/phi-1_5": ["https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin"],
+    "microsoft/phi-2": [
+        "https://huggingface.co/microsoft/phi-2/blob/main/model-00001-of-00002.safetensors",
+        "https://huggingface.co/microsoft/phi-2/blob/main/model-00002-of-00002.safetensors",
+    ],
+}
+
+PHI_MAPPING = {
+    "transformer.embd.wte.weight": "model.embed_tokens.weight",
+    "lm_head.linear": "lm_head",
+    "lm_head.ln": "model.final_layernorm",
+    "layers": "model.layers",
+    "transformer": "model",
+    ".h.": ".layers.",
+    "ln": "input_layernorm",
+    "mixer": "self_attn",
+    "Wqkv": "query_key_value",
+    "out_proj": "dense",
+}
+
+
+def convert_weights(original_weights, mapping, config):
+    converted_weights = {}
+    original_weights_keys = sorted(original_weights.keys())
+
+    for original_weights_key in original_weights_keys:
+        new_key = original_weights_key
+
+        if "rotary_emb" in new_key:
+            continue
+
+        if "Wqkv" in new_key:
+            if "weight" in new_key:
+                weight = original_weights[new_key]
+                weights_shape = weight.shape
+                weight = (
+                    weight.view(3, config.num_attention_heads, -1, config.hidden_size)
+                    .transpose(0, 1)
+                    .reshape(*weights_shape)
+                )
+                original_weights[new_key] = weight
+            elif "bias" in new_key:
+                bias = original_weights[new_key]
+                bias_shape = bias.shape
+                bias = bias.view(3, config.num_attention_heads, -1).transpose(0, 1).reshape(*bias_shape)
+                original_weights[new_key] = bias
+
+        for k, v in mapping.items():
+            if k in new_key:
+                new_key = new_key.replace(k, v)
+
+        converted_weights[new_key] = original_weights.pop(original_weights_key)
+
+    return converted_weights
+
+
+def _download(url: str, root: str):
+    repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
+    filename = f"{url.split('/')[-1]}"
+    hf_hub_download(
+        repo_id=repo_id,
+        filename=filename,
+        force_filename=root,
+        local_dir_use_symlinks=False,
+    )
+
+
+def convert_phi_weights(
+    model_name, checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly, _MODELS
+):
+    _MODELS = _MODELS if model_name not in _MODELS.keys() else {model_name: _MODELS.get(model_name)}
+    device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+    for model_name, model_url in _MODELS.items():
+        converted_checkpoint = {}
+        model_checkpoint = {}
+
+        # for phi-2 the weights are stored in 2 different safetensors files, so we need to iterate over that list
+        # and download one file at a time
+        for model_each_url in model_url:
+            model_path = 
os.path.join(checkpoint_path, model_name + "_" + model_each_url.split("/")[-1])
+            if not os.path.exists(model_path):
+                print(f"\n{model_name} was not found! Downloading it to {model_path}")
+                _download(url=model_each_url, root=model_path)
+
+            if model_path.endswith("safetensors"):
+                loaded_weights = safetensors.torch.load_file(model_path, device=device)
+            else:
+                loaded_weights = torch.load(model_path, map_location=device, weights_only=True)
+            model_checkpoint.update(**loaded_weights)
+
+        model_type = model_name.split("/")[1]  # phi-1 or phi-1_5 or phi-2
+
+        # init the config for phi-1 and phi-1.5
+        config = PhiConfig()
+        # if we are dealing with phi-2 then update the config
+        if model_type == "phi-2":
+            config.hidden_size = 2560
+            config.intermediate_size = 10240
+            config.num_hidden_layers = 32
+            config.resid_pdrop = 0.1
+            config.partial_rotary_factor = 0.4
+            config.torch_dtype = "float16"
+
+        # Converting the weights
+        converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
+
+        # Save either the whole model or the converted weights
+        if save_weights_directly:
+            save_weights_path = os.path.join(pytorch_dump_folder_path, model_type + "_pytorch_model.bin")
+            torch.save(converted_checkpoint, save_weights_path)
+            print(f"Model weights saved at {save_weights_path}!")
+
+        else:
+            model = PhiForCausalLM(config).to(device)
+            model.load_state_dict(converted_checkpoint, strict=True)
+            save_model_path = os.path.join(pytorch_dump_folder_path, model_type)
+            model.save_pretrained(save_model_path)
+            print(f"Model saved at {save_model_path}!")
+
+            # release GPU memory for this model if cuda was used.
+            del config, model
+
+        # release checkpoint memory before moving on to the next model.
+        del model_checkpoint, converted_checkpoint
+        if use_cuda:
+            torch.cuda.empty_cache()
+        gc.collect()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Name of the model to convert. (Please enter one of the following: phi-1, phi-1_5, phi-2). If nothing is provided, all models will be converted.",
+        default=None,
+    )
+    parser.add_argument(
+        "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path",
+        default=None,
+        type=str,
+        help="Path to the output PyTorch model. (Please enter full path)",
+    )
+    parser.add_argument(
+        "--use_cuda",
+        action="store_true",
+        help="Whether to load the weights on GPU during conversion or not, False by default",
+    )
+    parser.add_argument(
+        "--save_weights_directly",
+        default=True,
+        action=argparse.BooleanOptionalAction,
+        help="Whether to save the weights directly after conversion or load the weight to the Phi model and then save "
+        "the Phi model along with weights. 
True by default", + ) + + args = parser.parse_args() + convert_phi_weights( + args.model_name, + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.use_cuda, + args.save_weights_directly, + _MODELS, + ) diff --git a/docs/transformers/src/transformers/models/phi/modeling_phi.py b/docs/transformers/src/transformers/models/phi/modeling_phi.py new file mode 100644 index 0000000000000000000000000000000000000000..8b8bd4b8e82c4d5b546f45f9f9b27a4036ee56b3 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi/modeling_phi.py @@ -0,0 +1,1014 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi/modular_phi.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_phi import PhiConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/phi-1" +_CONFIG_FOR_DOC = "PhiConfig" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class PhiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> 
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.dense(attn_output) + return attn_output, attn_weights + + +class PhiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class PhiDecoderLayer(nn.Module): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__() + self.self_attn = PhiAttention(config, layer_idx=layer_idx) + self.mlp = PhiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + attn_outputs = self.resid_dropout(attn_outputs) + + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class PhiRotaryEmbedding(nn.Module): + def __init__(self, config: PhiConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+PHI_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`PhiConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Phi Model outputting raw hidden-states without any specific head on top.",
+    PHI_START_DOCSTRING,
+)
+class PhiPreTrainedModel(PreTrainedModel):
+    config_class = PhiConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["PhiDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+    _supports_attention_backend = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
+
+
+PHI_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
+            but you can also pass a `BlockMask` object directly here.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Phi Model outputting raw hidden-states without any specific head on top.",
+    PHI_START_DOCSTRING,
+)
+class PhiModel(PhiPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`PhiDecoderLayer`] + + Args: + config: PhiConfig + """ + + def __init__(self, config: PhiConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = PhiRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) # diff with Llama + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
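+        # Added note (sketch of the logic below, not original code): for a purely causal, padding-free
+        # batch the SDPA branch returns `None` so `scaled_dot_product_attention` can use its fused
+        # `is_causal=True` path. E.g., assuming a hypothetical call with `attention_mask = torch.ones(2, 16)`
+        # and a `DynamicCache`, `_ignore_causal_mask_sdpa(...)` returns True and no 4D mask is built.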
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
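+            # Added clarification (not original code): "inverted form" means the additive float convention
+            # used below: 0.0 where attention is allowed and `torch.finfo(dtype).min` where it is blocked,
+            # so the mask can be summed directly onto the attention scores. For one query over three keys,
+            # [0.0, 0.0, min_dtype] keeps keys 0-1 and removes key 2 after the softmax.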
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = PhiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. 
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PhiForCausalLM
+
+        >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
+
+        >>> prompt = "This is an example script ."
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str]'
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The Phi Model transformer with a sequence classification head on top (linear layer).
+
+    [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
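+
+    A minimal usage sketch (illustrative only; `my-phi-seqcls` is a hypothetical fine-tuned checkpoint,
+    not a real repository):
+
+    ```python
+    from transformers import AutoTokenizer, PhiForSequenceClassification
+
+    tokenizer = AutoTokenizer.from_pretrained("my-phi-seqcls")  # hypothetical checkpoint
+    model = PhiForSequenceClassification.from_pretrained("my-phi-seqcls", num_labels=2)
+    inputs = tokenizer("A thoroughly positive review", return_tensors="pt")
+    predicted_class = model(**inputs).logits.argmax(-1)  # pooled from the last non-pad token
+    ```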
+ """, + PHI_START_DOCSTRING, +) +class PhiForSequenceClassification(PhiPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PhiModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Phi Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+    """,
+    PHI_START_DOCSTRING,
+)
+class PhiForTokenClassification(PhiPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = PhiModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TokenClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
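+
+            For example (illustrative values): with one 5-token sequence and 3 entity classes, `labels`
+            could be `torch.tensor([[0, 1, 1, 0, 2]])`, i.e. of shape `(batch_size, sequence_length)`.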
+ """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "PhiPreTrainedModel", + "PhiModel", + "PhiForCausalLM", + "PhiForSequenceClassification", + "PhiForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/phi/modular_phi.py b/docs/transformers/src/transformers/models/phi/modular_phi.py new file mode 100644 index 0000000000000000000000000000000000000000..5faee931e0a26e84876039f68e8c2da5890aff82 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi/modular_phi.py @@ -0,0 +1,329 @@ +from functools import partial +from typing import Callable, Optional, Tuple + +import torch +import torch.nn as nn + +from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..clip.modeling_clip import CLIPMLP +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaPreTrainedModel, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, # copied from Llama +) +from .configuration_phi import PhiConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/phi-1" +_CONFIG_FOR_DOC = "PhiConfig" + + +class PhiAttention(LlamaAttention): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + del self.o_proj + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = 
self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.dense(attn_output) + return attn_output, attn_weights + + +class PhiMLP(CLIPMLP): + pass + + +class PhiDecoderLayer(nn.Module): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__() + self.self_attn = PhiAttention(config, layer_idx=layer_idx) + self.mlp = PhiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + 
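+        # Added note (not original code): Phi uses a *parallel* residual layout rather than Llama's
+        # sequential one: attention and the MLP both read the same layernormed input, and the lines
+        # below combine them as residual + Dropout(Attn(LN(x))) + Dropout(MLP(LN(x))).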
attn_outputs = self.resid_dropout(attn_outputs) + + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class PhiRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class PhiPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +class PhiModel(LlamaModel): + def __init__(self, config: PhiConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + del self.norm + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) # diff with Llama + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class PhiForCausalLM(LlamaForCausalLM): + def __init__(self, config): + super().__init__(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + + +class PhiForSequenceClassification(LlamaForSequenceClassification): + pass + + +class PhiForTokenClassification(LlamaForTokenClassification): + pass + + +__all__ = [ + "PhiPreTrainedModel", + "PhiModel", + "PhiForCausalLM", + "PhiForSequenceClassification", + "PhiForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/phi3/__init__.py b/docs/transformers/src/transformers/models/phi3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb1e7a9cd04fb32cb8eb29516d95c0fc8e9d108 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi3/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_phi3 import * + from .modeling_phi3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/phi3/configuration_phi3.py b/docs/transformers/src/transformers/models/phi3/configuration_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..ec133fbd7d504131eb0213f4d1972f85504246ac --- /dev/null +++ b/docs/transformers/src/transformers/models/phi3/configuration_phi3.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Phi-3 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            Dropout probability for the MLP outputs.
+        embd_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after computing the attention scores.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model was trained with. This is used to determine the size of the
+            original RoPE embeddings when using long scaling.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value used for the RMSNorm.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the input and output word embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
+            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+            divided by the number of attention heads divided by 2.
+        partial_rotary_factor (`float`, *optional*, defaults to 1.0):
+            Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 32000):
+            The id of the "end-of-sequence" token.
+        pad_token_id (`int`, *optional*, defaults to 32000):
+            The id of the padding token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If `None`, no sliding window is applied.
+ + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1.0, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
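+
+        For illustration (values are made up), a `rope_scaling` dict that passes this validation for a
+        model with `hidden_size=3072`, `num_attention_heads=32` and `partial_rotary_factor=1.0`
+        (so `rotary_ndims // 2 == 48` entries per factor list) could look like:
+
+        ```python
+        rope_scaling = {
+            "type": "longrope",
+            "short_factor": [1.0] * 48,  # one scaling factor per rotary dimension pair
+            "long_factor": [2.0] * 48,
+        }
+        ```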
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor) + if not len(rope_scaling_short_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}" + ) + + +__all__ = ["Phi3Config"] diff --git a/docs/transformers/src/transformers/models/phi3/modeling_phi3.py b/docs/transformers/src/transformers/models/phi3/modeling_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf5332c7176ac1d4b0de76652f343a7da3b99e2 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi3/modeling_phi3.py @@ -0,0 +1,1126 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi3/modular_phi3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_phi3 import Phi3Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
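+
+    A shape sketch (added for illustration; tensor names are hypothetical):
+
+    ```python
+    # q, k: [batch, num_heads, seq_len, head_dim]; cos, sin: [batch, seq_len, rotary_dim]
+    # unsqueeze_dim=1 broadcasts cos/sin over the heads axis; only the first rotary_dim
+    # channels are rotated, and the q_pass/k_pass remainder is concatenated back unchanged.
+    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
+    ```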
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+PHI3_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Phi3Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
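+
+    Example (a minimal usage sketch, not part of the original docstring):
+
+    ```python
+    >>> from transformers import Phi3Config, Phi3Model
+
+    >>> # Initializing from a config builds randomly initialized weights...
+    >>> model = Phi3Model(Phi3Config())
+
+    >>> # ...while `from_pretrained` also downloads and loads the checkpoint weights
+    >>> model = Phi3Model.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    ```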
+""" + + +@add_start_docstrings( + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", + PHI3_START_DOCSTRING, +) +class Phi3PreTrainedModel(PreTrainedModel): + config_class = Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Phi3RMSNorm): + module.weight.data.fill_(1.0) + + +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, config: Phi3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +PHI3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. 
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Phi3 Model outputting raw hidden-states without any specific head on top.",
+    PHI3_START_DOCSTRING,
+)
+class Phi3Model(Phi3PreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
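+        # With `attn_mask=None` and `is_causal=True`, SDPA can dispatch to its fused flash / memory-efficient
+        # kernels. A static cache pads the key/value length to the full cache size, so causality can no longer
+        # be inferred from the query/key shapes alone and the explicit 4D mask below is required.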
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Phi3Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+            config (`Phi3Config`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.get_text_config().sliding_window is not None:
+                # If we have a sliding window, we should not attend to tokens beyond the sliding window length, so we
+                # mask them out as well. The check is needed to verify if the current checkpoint was trained with a
+                # sliding window or not.
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
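+
+
+def _sliding_window_mask_sketch():
+    # A minimal illustrative sketch (not part of the original file): reproduces, in isolation, the
+    # boolean arithmetic used by `_prepare_4d_causal_attention_mask_with_cache_position` above.
+    # All values are hypothetical.
+    cache_position = torch.arange(3, 6)  # query positions 3..5 are being decoded
+    target_length, sliding_window = 6, 2
+    key_idx = torch.arange(target_length)
+    future = key_idx > cache_position.reshape(-1, 1)  # mask strictly-future key positions
+    too_old = key_idx <= cache_position.reshape(-1, 1) - sliding_window  # and keys outside the window
+    # The row for query position 5 ends up attending only to key positions 4 and 5.
+    return future | too_old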
+
+
+class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Phi3Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[KwargsForCausalLM],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+        >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +@add_start_docstrings( + """ + The Phi3 Model transformer with a sequence classification head on top (linear layer). + + [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3_START_DOCSTRING, +) +class Phi3ForSequenceClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Phi3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + PHI3_START_DOCSTRING, +) +class Phi3ForTokenClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Phi3PreTrainedModel", + "Phi3Model", + "Phi3ForCausalLM", + "Phi3ForSequenceClassification", + "Phi3ForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/phi3/modular_phi3.py b/docs/transformers/src/transformers/models/phi3/modular_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..4e38e3164df9fa36c5d2eb1dbb2c34ea4a7a90c0 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi3/modular_phi3.py @@ -0,0 +1,305 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phi-3 model.""" + +from typing import Callable, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralPreTrainedModel, + eager_attention_forward, + rotate_half, +) +from .configuration_phi3 import Phi3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi3DecoderLayer(MistralDecoderLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.config = config + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+class Phi3PreTrainedModel(MistralPreTrainedModel):
+    _version = "0.0.5"
+
+
+class Phi3ForCausalLM(MistralForCausalLM, Phi3PreTrainedModel):
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
+        # process
+
+        # When the input length first crosses the long/short factor switching point, force the cache to be
+        # re-computed. This makes that single token position slower, but it is better than failing outright.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+        return model_inputs
+
+
+class Phi3ForSequenceClassification(MistralForSequenceClassification):
+    pass
+
+
+class Phi3ForTokenClassification(MistralForTokenClassification):
+    pass
+
+
+__all__ = [
+    "Phi3PreTrainedModel",
+    "Phi3Model",  # noqa: F822
+    "Phi3ForCausalLM",
+    "Phi3ForSequenceClassification",
+    "Phi3ForTokenClassification",
+]
diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/__init__.py b/docs/transformers/src/transformers/models/phi4_multimodal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e2e599f57a8f9fe23254a46ff6fda67e82f13b
--- /dev/null
+++ b/docs/transformers/src/transformers/models/phi4_multimodal/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_phi4_multimodal import * + from .feature_extraction_phi4_multimodal import * + from .image_processing_phi4_multimodal_fast import * + from .modeling_phi4_multimodal import * + from .processing_phi4_multimodal import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py b/docs/transformers/src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..3f776b0b71ecfd6e57ebe8a0c2c2f320d281ffc0 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py @@ -0,0 +1,482 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi4_multimodal.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from ...configuration_utils import PretrainedConfig + + +class Phi4MultimodalVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalVisionModel`]. It is used to instantiate a + Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 4304): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 27): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 448): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + crop_size (`int`, *optional*, defaults to 448): + Crop size for the input images. + image_token_id (`int`, *optional*, defaults to 200010): + The image token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract image features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalVisionConfig + + >>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalVisionConfig() + ```""" + + model_type = "phi4_multimodal_vision" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=448, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + crop_size: int = 448, + image_token_id: int = 200010, + feature_layer: int = -2, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.crop_size = crop_size + self.image_token_id = image_token_id + self.feature_layer = feature_layer + + +class Phi4MultimodalAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalAudioModel`]. It is used to instantiate a + Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers. 
+ intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_blocks (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the MLPs. + chunk_size (`int`, *optional*, defaults to -1): + The chunk size to create the masks. + left_chunk (`int`, *optional*, defaults to 18): + The left chunk to create the masks. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio. + ext_pw_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the point-wise conv modules. + depthwise_seperable_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the depth-wise separable conv modules. + depthwise_multiplier (`int`, *optional*, defaults to 1): + Input size multiplier for the depth-wise separable conv modules. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the depth-wise separable conv modules. + conv_activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the conv modules. + input_size (`int`, *optional*, defaults to 80): + Input size for the audio model. + conv_glu_type (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the point-wise conv modules. + time_reduction (`int`, *optional*, defaults to 8): + Time reduction (subsampling factor). + bias_max_distance (`int`, *optional*, defaults to 1000): + Max distance for the relative attention bias module. + bias_symmetric (`bool`, *optional*, defaults to `False`): + Whether the relative attention bias should be symmetric or not. + nemo_activation (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the nemo conv modules. + nemo_conv_channels (`int`, *optional*, defaults to 1024): + Number of channels in the nemo conv modules. + downsample_rate (`int`, *optional*, defaults to 1): + Downsample rate for the audio feature extractor. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + audio_token_id (`int`, *optional*, defaults to 200011): + The audio token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract audio features. 
+ + Example: + + ```python + >>> from transformers import Phi4MultimodalAudioConfig + + >>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalAudioConfig() + ```""" + + model_type = "phi4_multimodal_audio" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 1536, + num_blocks: int = 24, + num_attention_heads: int = 16, + activation: str = "swish", + chunk_size: int = -1, + left_chunk: int = 18, + dropout_rate: float = 0.0, + ext_pw_out_channel: int = 1024, + depthwise_seperable_out_channel: int = 1024, + depthwise_multiplier: int = 1, + kernel_size: int = 3, + conv_activation: str = "swish", + input_size: int = 80, + conv_glu_type: str = "swish", + time_reduction: int = 8, + bias_max_distance: int = 1000, + bias_symmetric: bool = False, + nemo_activation: str = "relu", + nemo_conv_channels: int = 1024, + downsample_rate: int = 1, + initializer_range: float = 0.02, + audio_token_id: int = 200011, + feature_layer: int = -2, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.activation = activation + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.num_blocks = num_blocks + self.dropout_rate = dropout_rate + self.ext_pw_out_channel = ext_pw_out_channel + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.depthwise_multiplier = depthwise_multiplier + self.kernel_size = kernel_size + self.conv_activation = conv_activation + self.input_size = input_size + self.conv_glu_type = conv_glu_type + self.time_reduction = time_reduction + self.bias_max_distance = bias_max_distance + self.bias_symmetric = bias_symmetric + self.nemo_activation = nemo_activation + self.nemo_conv_channels = nemo_conv_channels + self.downsample_rate = downsample_rate + self.audio_token_id = audio_token_id + self.initializer_range = initializer_range + self.feature_layer = feature_layer + + if time_reduction % 2 != 0: + raise ValueError("`time_reduction` should be a multiple of 2!") + length = input_size + for _ in range(int(math.log(time_reduction, 2))): + length = math.floor((length - 1) / 2 + 1) + self.nemo_final_size = length + + +class Phi4MultimodalConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a + Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 200064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. 
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            Dropout probability for mlp outputs.
+        embd_pdrop (`int`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after computing the attention scores.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value used for the RMSNorm.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings or not.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
+            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+            divided by the number of attention heads divided by 2.
+        partial_rotary_factor (`float`, *optional*, defaults to `1.0`):
+            Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0.
+        bos_token_id (`int`, *optional*, defaults to 199999):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int` or `list[int]`, *optional*, defaults to `[199999, 200020]`):
+            The id of the "end-of-sequence" token.
+        pad_token_id (`int`, *optional*, defaults to 199999):
+            The id of the padding token.
+        original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model was trained with. This is used to determine the size of the
+            original RoPE embeddings when using long scaling.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If `None`, no sliding window is applied.
+        vision_config (`Phi4MultimodalVisionConfig` or `dict`, *optional*):
+            The vision config for the underlying image embedding model.
If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + audio_config (`Phi4MultimodalAudioConfig` or `dict`, *optional*): + The audio config for the underlying audio embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + + Example: + + ```python + >>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig + + >>> # Initializing a Phi4Multimodal style configuration + >>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi4MultimodalModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi4_multimodal" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig} + + def __init__( + self, + vocab_size=200064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1, + bos_token_id=199999, + eos_token_id=[199999, 200020], + pad_token_id=199999, + original_max_position_embeddings=4096, + sliding_window=None, + vision_config=None, + audio_config=None, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_adjustment() + self._rope_scaling_validation() + 
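+        # Illustrative values (editorial comment, not in the original file): with the defaults
+        # hidden_size=3072, num_attention_heads=32 and partial_rotary_factor=1.0, the rotary
+        # dimension is 3072 // 32 = 96, so a valid `rope_scaling` dict must look like
+        # {"type": "longrope", "short_factor": [...], "long_factor": [...]} where each factor
+        # list has length 96 // 2 = 48.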
+        self.sliding_window = sliding_window
+
+        if isinstance(vision_config, dict):
+            vision_config = Phi4MultimodalVisionConfig(**vision_config)
+        elif vision_config is None:
+            vision_config = Phi4MultimodalVisionConfig()
+        self.vision_config = vision_config
+
+        if isinstance(audio_config, dict):
+            audio_config = Phi4MultimodalAudioConfig(**audio_config)
+        elif audio_config is None:
+            audio_config = Phi4MultimodalAudioConfig()
+        self.audio_config = audio_config
+
+    def _rope_scaling_adjustment(self):
+        """
+        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
+        """
+        if self.rope_scaling is None:
+            return
+
+        rope_scaling_type = self.rope_scaling.get("type", None)
+
+        # For backward compatibility if previous version used "su" or "yarn"
+        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
+            self.rope_scaling["type"] = "longrope"
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
+        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
+        if not (
+            isinstance(rope_scaling_short_factor, list)
+            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+            )
+        rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
+        if not len(rope_scaling_short_factor) == rotary_ndims // 2:
+            raise ValueError(
+                f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
+            )
+        if not (
+            isinstance(rope_scaling_long_factor, list)
+            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+            )
+        if not len(rope_scaling_long_factor) == rotary_ndims // 2:
+            raise ValueError(
+                f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
+            )
+
+
+__all__ = ["Phi4MultimodalVisionConfig", "Phi4MultimodalAudioConfig", "Phi4MultimodalConfig"]
diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py b/docs/transformers/src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a4ac90ac7c3057ec592fee004b3d9de5e6f0d5
--- /dev/null
+++ b/docs/transformers/src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py
@@ -0,0 +1,271 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import re + +import torch +from peft import LoraConfig +from safetensors.torch import load_file, save_file + +from transformers import ( + AutoProcessor, + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalFeatureExtractor, + Phi4MultimodalForCausalLM, + Phi4MultimodalImageProcessorFast, + Phi4MultimodalProcessor, + Phi4MultimodalVisionConfig, +) + + +CHAT_TEMPLATE = "{% for message in messages %}{{ '<|' + message['role'] + '|>' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'audio' %}{{ '<|audio|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% endif %}{{ '<|end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}" + +# fmt: off +STATE_DICT_MAPPING = { + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.0.linear": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.gate_up_proj", + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.2": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.down_proj", + + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_(q|k|v)": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.\2_proj", + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_out": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.o_proj", + + r"^model.embed_tokens_extend.image_embed.img_projection.0": r"model.embed_tokens_extend.image_embed.img_projection_up", + r"^model.embed_tokens_extend.image_embed.img_projection.2": r"model.embed_tokens_extend.image_embed.img_projection_down", + + r"^model.embed_tokens_extend.image_embed.glb_GN": r"model.embed_tokens_extend.image_embed.global_img_feature_extensor", + r"^model.embed_tokens_extend.image_embed.sub_GN": r"model.embed_tokens_extend.image_embed.sub_img_feature_extensor", + + r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_speech", + r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_speech", + r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_vision_speech", + r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_vision_speech", +} +# fmt: on + + +def map_old_key_to_new(old_key): + """Map of a key of the original state dict to the equivalent key in HF format""" + for pattern, replacement in STATE_DICT_MAPPING.items(): + new_key, n_replace = re.subn(pattern, replacement, old_key) + # 
Early exit of the loop
+        if n_replace > 0:
+            return new_key
+
+    # The state dict contains LoRA keys; they are handled separately by `extract_adapters_data`, so skip them here
+    if "lora" in old_key:
+        return None
+    # This extracts the original weight before adding the lora adapter
+    if "base_layer." in old_key:
+        return old_key.replace("base_layer.", "")
+
+    # not part of the key mapping, we keep the original name
+    return old_key
+
+
+def convert_state_dict(original_state_dict: dict):
+    """Convert a state dict file."""
+    new_dict = {}
+    for old_key, tensor in original_state_dict.items():
+        new_key = map_old_key_to_new(old_key)
+        if new_key is not None:
+            new_dict[new_key] = tensor
+    return new_dict
+
+
+def convert_config(original_config: dict):
+    # Remove unused args
+    original_config.pop("_name_or_path", None)
+    original_config.pop("architectures", None)
+    original_config.pop("auto_map", None)
+    original_config.pop("vision_lora", None)
+    original_config.pop("speech_lora", None)
+    original_config.pop("transformers_version", None)
+    original_config.pop("_attn_implementation", None)
+
+    embd_layer = original_config.pop("embd_layer")
+    audio_embd_layer = embd_layer["audio_embd_layer"]
+    vision_embd_layer = embd_layer["image_embd_layer"]
+
+    # Keep only the needed keys from each sub-dict
+    keep_audio_embd_layer = ["downsample_rate"]
+    keep_vision_embd_layer = ["crop_size"]
+    audio_embd_layer = {k: v for k, v in audio_embd_layer.items() if k in keep_audio_embd_layer}
+    vision_embd_layer = {k: v for k, v in vision_embd_layer.items() if k in keep_vision_embd_layer}
+
+    audio_config = original_config.pop("audio_processor")["config"]
+    # remove unused keys
+    audio_config.pop("activation_checkpointing", None)
+    audio_config.pop("cnn_layer_norm", None)
+    audio_config.pop("input_layer", None)
+    audio_config.pop("batch_norm", None)
+    audio_config.pop("encoder_embedding_config", None)
+    audio_config.pop("ext_pw_kernel_size", None)
+    audio_config.pop("bias_in_glu", None)
+    audio_config.pop("causal", None)
+    # rename keys to the transformers conventions
+    audio_config["hidden_size"] = audio_config.pop("attention_dim")
+    audio_config["num_attention_heads"] = audio_config.pop("attention_heads")
+    audio_config["intermediate_size"] = audio_config.pop("linear_units")
+    audio_config["nemo_conv_channels"] = audio_config.pop("nemo_conv_settings")["conv_channels"]
+    audio_config["bias_max_distance"] = audio_config.pop("relative_attention_bias_args")["t5_bias_max_distance"]
+    # add the kept embd_layer keys
+    audio_config = {**audio_config, **audio_embd_layer}
+
+    # Create transformers config objects
+    audio_config = Phi4MultimodalAudioConfig(**audio_config)
+    vision_config = Phi4MultimodalVisionConfig(**vision_embd_layer)
+
+    # Add 2nd eos to config
+    original_config["eos_token_id"] = [199999, 200020]
+
+    new_config = Phi4MultimodalConfig(**original_config, vision_config=vision_config, audio_config=audio_config)
+    return new_config
+
+
+def read_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def convert_and_write_model(input_dir: str, output_dir: str):
+    """Convert the model and save it (this implicitly saves the config as well)."""
+    original_config = read_json(os.path.join(input_dir, "config.json"))
+    config = convert_config(original_config)
+
+    full_state_dict = {}
+    shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
+    for shard_file in shards:
+        original_state_dict = load_file(os.path.join(input_dir, shard_file))
+        new_dict = convert_state_dict(original_state_dict)
+        full_state_dict.update(new_dict)
+
+    # Load weights into model and resave them
+    with torch.device("meta"):
+        model = Phi4MultimodalForCausalLM(config)
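+    # Editorial note (not in the original script): `torch.device("meta")` builds the module
+    # without allocating real parameter storage, and `assign=True` below makes `load_state_dict`
+    # adopt the checkpoint tensors directly instead of copying into the meta tensors, so the
+    # weights are materialized exactly once.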
+    missing, unexpected = model.load_state_dict(full_state_dict, strict=False, assign=True)
+    # The lm_head is missing because it's tied
+    if missing != ["lm_head.weight"]:
+        raise ValueError(f"Missing keys:\n{missing}")
+    if len(unexpected) > 0:
+        raise ValueError(f"Unexpected keys:\n{unexpected}")
+
+    model.tie_weights()
+    model.save_pretrained(output_dir)
+
+
+def convert_and_save_processor(input_dir: str, output_dir: str):
+    """Convert the processor."""
+    original_processor = AutoProcessor.from_pretrained(input_dir, trust_remote_code=True)
+    original_processor.tokenizer.extra_special_tokens = {"image_token": "<|image|>", "audio_token": "<|audio|>"}
+    # We need to add those temporarily to instantiate the processor
+    original_processor.tokenizer.image_token = "<|image|>"
+    original_processor.tokenizer.audio_token = "<|audio|>"
+    original_processor.tokenizer.image_token_id = 200010
+    original_processor.tokenizer.audio_token_id = 200011
+
+    converted_processor = Phi4MultimodalProcessor(
+        tokenizer=original_processor.tokenizer,
+        image_processor=Phi4MultimodalImageProcessorFast(),
+        audio_processor=Phi4MultimodalFeatureExtractor(),
+        chat_template=CHAT_TEMPLATE,
+    )
+    # We remove them before saving to avoid polluting the saved tokenizer files
+    del converted_processor.tokenizer.image_token
+    del converted_processor.tokenizer.image_token_id
+    del converted_processor.tokenizer.audio_token
+    del converted_processor.tokenizer.audio_token_id
+
+    # Save the processor
+    converted_processor.save_pretrained(output_dir)
+
+    # We need to rename a few tokens, but `tokenizers` doesn't allow doing that programmatically.
+    # To avoid confusion and manual renaming, the code below loads and re-saves each json file
+    vocab = json.load(open(f"{output_dir}/vocab.json", "r"))
+    vocab["<|endoftext11|>"] = "<|audio|>"
+    vocab["<|endoftext10|>"] = "<|image|>"
+    json.dump(vocab, open(f"{output_dir}/vocab.json", "w"))
+
+    tokenizer = json.load(open(f"{output_dir}/tokenizer.json", "r"))
+    tokenizer["added_tokens"][1]["content"] = "<|image|>"
+    tokenizer["added_tokens"][2]["content"] = "<|audio|>"
+    tokenizer["model"]["vocab"]["<|audio|>"] = tokenizer["model"]["vocab"]["<|endoftext11|>"]
+    tokenizer["model"]["vocab"]["<|image|>"] = tokenizer["model"]["vocab"]["<|endoftext10|>"]
+    del tokenizer["model"]["vocab"]["<|endoftext11|>"]
+    del tokenizer["model"]["vocab"]["<|endoftext10|>"]
+    json.dump(tokenizer, open(f"{output_dir}/tokenizer.json", "w"))
+
+    tokenizer_config = json.load(open(f"{output_dir}/tokenizer_config.json", "r"))
+    tokenizer_config["added_tokens_decoder"]["200010"]["content"] = "<|image|>"
+    tokenizer_config["added_tokens_decoder"]["200011"]["content"] = "<|audio|>"
+    json.dump(tokenizer_config, open(f"{output_dir}/tokenizer_config.json", "w"))
+
+
+def extract_adapters_data(input_dir: str, output_dir: str):
+    """Extract adapters data from the state dict and save weights and configs."""
+    speech_lora = {}
+    vision_lora = {}
+    shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
+    for shard_file in shards:
+        original_state_dict = load_file(os.path.join(input_dir, shard_file))
+        for k, v in original_state_dict.items():
+            if "lora" in k:
+                if "speech" in k:
+                    speech_lora[k.replace("speech.", "")] = v
+                elif "vision" in k:
+                    vision_lora[k.replace("vision.", "")] = v
+
+    # Create and save the lora configs
+    speech_lora_config = LoraConfig(
+        r=320,
+        lora_alpha=640,
+        target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))",
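+        # (editorial comment, not in the original) this regex matches the attention qkv/o
+        # projections and the MLP gate_up/down projections of every decoder layer; the same
+        # pattern is reused for the vision adapter below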
+        lora_dropout=0.01,
+        task_type="CAUSAL_LM",
+    )
+    speech_lora_config.save_pretrained(os.path.join(output_dir, "speech-lora"))
+    vision_lora_config = LoraConfig(
+        r=256,
+        lora_alpha=512,
+        target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))",
+        lora_dropout=0.0,
+        task_type="CAUSAL_LM",
+    )
+    vision_lora_config.save_pretrained(os.path.join(output_dir, "vision-lora"))
+
+    save_file(speech_lora, os.path.join(output_dir, "speech-lora", "adapter_model.safetensors"))
+    save_file(vision_lora, os.path.join(output_dir, "vision-lora", "adapter_model.safetensors"))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "input_dir",
+        help="Location of the model folder containing the weights and configs.",
+    )
+    parser.add_argument(
+        "output_dir",
+        help="Location to write HF model.",
+    )
+    args = parser.parse_args()
+
+    # Convert
+    convert_and_write_model(args.input_dir, args.output_dir)
+    convert_and_save_processor(args.input_dir, args.output_dir)
+    extract_adapters_data(args.input_dir, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py b/docs/transformers/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d29af6c8bfb35dd22703518ff41879129e4d6f9
--- /dev/null
+++ b/docs/transformers/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py
@@ -0,0 +1,348 @@
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Processor class for Phi4Multimodal
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...audio_utils import AudioInput
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...image_processing_utils import BatchFeature
+from ...utils import TensorType, is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: @eustlb, remove this once #36603 is merged.
+def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
+    """Create a Mel filter-bank the same as SpeechLib FbankFC.
+
+    Args:
+        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
+        n_fft (int): FFT size. int > 0 [scalar]
+        n_mels (int): Mel filter size. int > 0 [scalar]
+        fmin (float): lowest frequency (in Hz). If None use 0.0.
+            float >= 0 [scalar]
+        fmax: highest frequency (in Hz). If None use sample_rate / 2.
+            float >= 0 [scalar]
+
+    Returns:
+        out (numpy.ndarray): Mel transform matrix
+            [shape=(n_mels, 1 + n_fft/2)]
+    """
+
+    bank_width = int(n_fft // 2 + 1)
+    if fmax is None:
+        fmax = sample_rate / 2
+    if fmin is None:
+        fmin = 0
+    assert fmin >= 0, "fmin cannot be negative"
+    assert fmin < fmax <= sample_rate / 2, "fmax must be in (fmin, sample_rate / 2]"
+
+    def mel(f):
+        return 1127.0 * np.log(1.0 + f / 700.0)
+
+    def bin2mel(fft_bin):
+        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
+
+    def f2bin(f):
+        return int((f * n_fft / sample_rate) + 0.5)
+
+    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
+    klo = f2bin(fmin) + 1
+    khi = f2bin(fmax)
+
+    khi = max(khi, klo)
+
+    # Spec 2: SpeechLib uses triangles in Mel space
+    mlo = mel(fmin)
+    mhi = mel(fmax)
+    m_centers = np.linspace(mlo, mhi, n_mels + 2)
+    ms = (mhi - mlo) / (n_mels + 1)
+
+    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
+    for m in range(0, n_mels):
+        left = m_centers[m]
+        center = m_centers[m + 1]
+        right = m_centers[m + 2]
+        for fft_bin in range(klo, khi):
+            mbin = bin2mel(fft_bin)
+            if left < mbin < right:
+                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
+
+    return matrix
+
+
+class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
+    model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
+
+    def __init__(
+        self,
+        feature_size: int = 80,
+        sampling_rate: int = 16000,
+        hop_length: int = 160,
+        n_fft: int = 512,
+        win_length: int = 400,
+        preemphasis: float = 0.97,
+        padding_value: float = 0.0,
+        audio_compression_rate: int = 8,
+        audio_downsample_rate: int = 1,
+        audio_feat_stride: int = 1,
+        mel_min_frequency: float = 0,
+        mel_max_frequency: float = 7690,
+        **kwargs,
+    ):
+        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+        self.hop_length = hop_length
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.preemphasis = preemphasis
+        self.padding_value = padding_value
+        self.audio_compression_rate = audio_compression_rate
+        self.audio_downsample_rate = audio_downsample_rate
+        self.audio_feat_stride = audio_feat_stride
+
+        # TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged.
+        # self.mel_filters = mel_filter_bank(
+        #     num_frequency_bins=self.n_fft // 2 + 1,
+        #     num_mel_filters=self.feature_size,
+        #     min_frequency=mel_min_frequency,
+        #     max_frequency=mel_max_frequency,
+        #     sampling_rate=self.sampling_rate,
+        #     triangularize_in_mel_space=True,
+        #     mel_scale="kaldi",
+        # )
+        self.mel_filters = speechlib_mel(
+            self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency
+        ).T
+
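+    # Editorial note (not in the original file): with the defaults n_fft=512 and
+    # feature_size=80, `speechlib_mel` returns an (80, 257) matrix, so after the transpose
+    # `self.mel_filters` has shape (257, 80) and maps the 257 rFFT bins onto 80 mel bands
+    # in `_torch_extract_fbank_features` below.
+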
+    def __call__(
+        self,
+        raw_speech: AudioInput,
+        sampling_rate: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        padding: Optional[str] = "longest",
+        max_length: Optional[int] = None,
+        truncation: bool = False,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_attention_mask: Optional[bool] = True,
+        device: Optional[str] = "cpu",
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Main method to featurize and prepare for the model one or several audio sequence(s). The implementation uses
+        PyTorch for the STFT computation.
+
+        Args:
+            raw_speech (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch
+                tensor. For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single
+                numpy array or PyTorch tensor with first dimension being the batch size.
+            sampling_rate (`int`, *optional*):
+                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                `sampling_rate` at the forward call to prevent silent errors.
+            pad_to_multiple_of (`int`, *optional*):
+                If set, will pad the sequence to a multiple of the provided value.
+            padding (`str`, *optional*, defaults to `"longest"`):
+                Padding strategy. Can be `"longest"` to pad to the longest sequence in the batch, or a specific length.
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length.
+            truncation (`bool`, *optional*, defaults to `False`):
+                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of numpy arrays. Acceptable values are:
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+            return_attention_mask (`bool`, *optional*, defaults to `True`):
+                Whether to return the extracted audio input features' attention mask.
+            device (`str`, *optional*, defaults to `"cpu"`):
+                Specifies the device for computation of the audio features. (e.g., "cpu", "cuda")
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+            - **audio_input_features** -- Audio features extracted from the raw audio input, shape (batch_size, max_feature_length, feature_size).
+            - **audio_embed_sizes** -- Number of embedding vectors each audio sample maps to, shape (batch_size,).
+            - **audio_attention_mask** -- Attention mask for the audio input, shape (batch_size, max_feature_length).
+            If `return_tensors` is not specified, the fields will be PyTorch tensors if PyTorch is available, otherwise NumPy arrays.
+        """
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
+                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
+                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
+        # Convert to torch tensor
+        if isinstance(raw_speech, np.ndarray):
+            raw_speech = torch.tensor(raw_speech)
+        elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
+            raw_speech = [torch.tensor(speech) for speech in raw_speech]
+
+        is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
+        if is_batched_torch and len(raw_speech.shape) > 2:
+            logger.warning(
+                f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
+                "We will take the mean of the channels to convert to mono."
+            )
+            raw_speech = raw_speech.mean(-1)
+
+        is_batched_sequence = isinstance(raw_speech, (list, tuple))
+        if is_batched_sequence:
+            converted_speech = []
+            for speech in raw_speech:
+                if len(speech.shape) > 1:
+                    logger.warning(
+                        f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
+                        "We will take the mean of the channels to convert to mono."
+                    )
+                    speech = speech.mean(-1)
+                converted_speech.append(speech)
+            raw_speech = converted_speech
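+
+        # Editorial note: a trailing feature dimension of size 1 is added below so that `self.pad`
+        # treats each mono waveform as a (length, 1) feature sequence; it is squeezed away again
+        # right after padding.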
+        if is_batched_torch or is_batched_sequence:
+            raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
+        else:
+            raw_speech = [raw_speech[:, None].to(torch.float32)]
+
+        audio_lengths = [len(speech) for speech in raw_speech]
+
+        # convert into correct format for padding
+        batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths})
+        padded_inputs = self.pad(
+            batched_speech,
+            padding=padding,
+            max_length=max_length,
+            truncation=truncation,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_tensors="pt",
+        )
+        input_features = padded_inputs.audio_input_features.squeeze(-1)
+        audio_lengths = padded_inputs.audio_lengths
+
+        input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)
+
+        feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1
+        feature_lengths = feature_lengths * self.audio_feat_stride
+        audio_embed_sizes = self._compute_audio_embed_size(feature_lengths)
+
+        feature_attention_mask = (
+            torch.arange(0, feature_lengths.max()) if is_torch_available() else np.arange(0, feature_lengths.max())
+        )
+        feature_attention_mask = (
+            feature_attention_mask[None, :] < feature_lengths[:, None] if len(feature_lengths) > 1 else None
+        )
+
+        data = {
+            "audio_input_features": input_features,
+            "audio_embed_sizes": audio_embed_sizes,
+        }
+        if feature_attention_mask is not None and return_attention_mask:
+            data["audio_attention_mask"] = feature_attention_mask
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    # TODO: @eustlb, move this to audio_utils in a general spectrogram_batch function that handles torch and numpy
+    def _torch_extract_fbank_features(
+        self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu"
+    ) -> "torch.FloatTensor":
+        """
+        Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation.
+
+        Args:
+            waveform (`torch.FloatTensor` of shape `(batch_size, max_audio_length)`):
+                The batched waveforms.
+            audio_lengths (`torch.Tensor` of shape `(batch_size,)`):
+                The lengths of the waveforms along the max_audio_length dimension.
+            device (`str`, *optional*, defaults to `"cpu"`):
+                The device to run the computation on. (e.g., "cpu", "cuda")
+
+        Returns:
+            `torch.FloatTensor` of shape `(batch_size, max_feature_length, feature_size)`:
+                The log mel-scaled spectrogram of the batched waveforms.
+        """
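+        # Worked example (editorial, not in the original file): at 16 kHz with the default
+        # win_length=400 and hop_length=160, a 1-second (16000-sample) waveform yields
+        # (16000 - 400) // 160 + 1 = 98 frames, and each frame becomes one 80-dim log-mel row.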
+ """ + fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64) + + # batched implementation + batch_size = waveform.shape[0] + frames = waveform.unfold(-1, self.win_length, self.hop_length) + + # --- + # the unbatched (and unpaded) original implementation skips last few audio values that can't be included in a frame + # we need to ensure that the corresponding frames for the padded input also mask these values + if batch_size > 1: + frames = frames.clone() + # concerned batch indices + to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()] + if to_mask_batch_idxs.numel() > 0: + batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 + batch_idxs_up = audio_lengths[to_mask_batch_idxs] // self.hop_length + 1 + offset_idx = batch_idxs_down.min() + max_idx = batch_idxs_up.max() + + mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1) + mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & ( + mask < (batch_idxs_up - offset_idx).unsqueeze(1) + ) + mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length) + masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0) + frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames + # --- + + # apply pre-emphasis first order filter on fft windows + frames_prev = torch.roll(frames, 1, dims=-1) + frames_prev[:, :, 0] = frames_prev[:, :, 1] + frames = (frames - self.preemphasis * frames_prev) * 32768 + + # apply fft + S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1) + S = S.view(frames.shape[0], -1, S.shape[-1]) + S = S.to(torch.complex64) + + spec = torch.abs(S) + spec_power = spec**2 + + # apply triangular mel filter bank + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) + log_spec = torch.log(log_spec) + + return log_spec + + def _compute_audio_embed_size(self, audio_frames): + integer = audio_frames // self.audio_compression_rate + remainder = audio_frames % self.audio_compression_rate + result = integer + (remainder > 0).to(integer.dtype) + + integer = result // self.audio_downsample_rate + remainder = result % self.audio_downsample_rate + result = integer + (remainder > 0).to(integer.dtype) # qformer compression + + return result + + +__all__ = ["Phi4MultimodalFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py b/docs/transformers/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..813f89dbeaa6d29877c15c620d475e76aeb5528c --- /dev/null +++ b/docs/transformers/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py @@ -0,0 +1,263 @@ +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for Phi4Multimodal +""" + +import math +from typing import List, Optional, Union + +import torch +from torchvision.transforms import functional as F + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, + Unpack, + convert_to_rgb, +) +from ...image_utils import ImageInput, make_flat_list_of_images, valid_images +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + image_size: Optional[int] + patch_size: Optional[int] + dynamic_hd: Optional[int] + + +class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a Phi4Multimodal image processor. + """ + + image_size = 448 + patch_size = 14 + dynamic_hd = 36 + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + valid_init_kwargs = Phi4MultimodalFastImageProcessorKwargs + model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"] + + def __init__(self, **kwargs: Unpack[Phi4MultimodalFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * self.image_size * self.image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, max_num=36, min_num=1): + image_size = self.image_size + patch_size = self.patch_size + mask_size = image_size // patch_size + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + + # Calculate the ratio + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + new_size = (int(orig_width * ratio_height), target_height) + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + + attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0]))) + if padding_width >= patch_size: + 
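+            # Editorial comment (not in the original file): patches that fall entirely inside the
+            # right/bottom padding carry no image content, so the corresponding attention-mask
+            # columns (and rows, in the height branch below) are zeroed out.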
attention_mask[:, -math.floor(padding_width / patch_size) :] = 0 + if padding_height >= patch_size: + attention_mask[-math.floor(padding_height / patch_size) :, :] = 0 + + if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + raise ValueError(f"the aspect ratio is very extreme {new_size}") + + image = F.resize(image, [new_size[1], new_size[0]]) + resized_img = F.pad(image, [0, 0, padding_width, padding_height], fill=[255, 255, 255]) + + return resized_img, attention_mask + + def pad_to_max_num_crops(self, images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if B < max_crops: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + def pad_mask_to_max_num_crops(self, masks, max_crops=5): + B, H, W = masks.shape + if B < max_crops: + pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) + masks = torch.cat([masks, pad], dim=0) + return masks + + def preprocess( + self, + images: ImageInput, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + """ + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_flat_list_of_images(images) + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) + images = [convert_to_rgb(image) for image in images] + + image_size = self.image_size + patch_size = self.patch_size + mask_size = image_size // patch_size + imgs_and_masks = [self.dynamic_preprocess(image, max_num=self.dynamic_hd) for image in images] + images, image_attention_masks = [x[0] for x in imgs_and_masks], [x[1] for x in imgs_and_masks] + + images = [F.to_tensor(image) for image in images] + hd_images = [F.normalize(image, image_mean, image_std) for image in images] + global_image = [ + torch.nn.functional.interpolate( + image.unsqueeze(0).float(), + size=(image_size, image_size), + mode="bicubic", + ).to(image.dtype) + for image in hd_images + ] + + shapes = [[image.size(1), image.size(2)] for image in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] + global_attention_mask = [torch.ones((1, mask_size, mask_size)) for _ in hd_images] + + hd_images_reshape = [] + for im, (h, w) in zip(hd_images, shapes): + im = im.reshape(1, 3, h // image_size, image_size, w // image_size, image_size) + im = im.permute(0, 2, 4, 1, 3, 5) + im = im.reshape(-1, 3, image_size, image_size) + hd_images_reshape.append(im.contiguous()) + + attention_masks_reshape = [] + for mask, (h, w) in zip(image_attention_masks, mask_shapes): + mask = mask.reshape(h // mask_size, mask_size, w // mask_size, mask_size) + mask = mask.transpose(1, 2) + mask = mask.reshape(-1, mask_size, mask_size) + attention_masks_reshape.append(mask.contiguous()) + + downsample_attention_masks = [] + for mask, (h, w) in zip(attention_masks_reshape, mask_shapes): + mask = mask[:, 0::2, 0::2] + mask = mask.reshape( + h // mask_size, w // mask_size, mask_size // 2 + mask_size % 2, mask_size // 2 + mask_size % 2 + ) + mask = mask.transpose(1, 2) + mask = mask.reshape(mask.size(0) * mask.size(1), mask.size(2) * mask.size(3)) + downsample_attention_masks.append(mask) + + num_img_tokens = [ + 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks + ] + + hd_images_reshape = [ + torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape) + ] + hd_masks_reshape = [ + torch.cat([_global_mask] + [_mask], dim=0) + for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape) + ] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "image_pixel_values": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Phi4MultimodalImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/docs/transformers/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9979edf31c1321f5f942f83a26ecbe31706921 --- /dev/null +++ 
b/docs/transformers/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -0,0 +1,2247 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi4_multimodal.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +def simple_eager_attention_forward( + module: nn.Module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + 
dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Phi4MultimodalVisionAttention(nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalVisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Phi4MultimodalVisionAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Phi4MultimodalVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Phi4MultimodalVisionEncoderLayer`]. + + Args: + config: Phi4MultimodalVisionConfig + """ + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + # Ignore copy + @can_return_tuple + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
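+
+        Example (illustrative only; random embeddings stand in for real patch features):
+
+        ```python
+        >>> import torch
+        >>> from transformers.models.phi4_multimodal.configuration_phi4_multimodal import Phi4MultimodalVisionConfig
+        >>> from transformers.models.phi4_multimodal.modeling_phi4_multimodal import Phi4MultimodalVisionEncoder
+
+        >>> config = Phi4MultimodalVisionConfig()
+        >>> encoder = Phi4MultimodalVisionEncoder(config)
+        >>> inputs_embeds = torch.randn(1, 16, config.hidden_size)
+        >>> out = encoder(inputs_embeds=inputs_embeds, output_hidden_states=True)
+        >>> out.last_hidden_state.shape  # (1, 16, config.hidden_size)
+        >>> len(out.hidden_states)  # config.num_hidden_layers + 1 (input embeddings plus every layer)
+        ```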
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
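+
+    A minimal usage sketch (the `std` value is illustrative): the call fills `tensor`
+    in place, so the resulting values lie within `[mean + a * std, mean + b * std]`:
+
+    ```python
+    >>> w = torch.empty(64, 64)
+    >>> trunc_normal_tf_(w, std=0.02)
+    >>> assert w.abs().max() <= 2 * 0.02
+    ```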
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Phi4MultimodalVisionConfig + base_model_prefix = "phi4_vision" + supports_gradient_checkpointing = True + + _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Phi4MultimodalVisionEmbeddings): + width = ( + self.config.hidden_size + if isinstance(self.config, Phi4MultimodalVisionConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, Phi4MultimodalVisionAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, Phi4MultimodalVisionMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class Phi4MultimodalVisionEmbeddings(nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.num_patches_per_side = config.image_size // self.patch_size + + self.patch_embedding = nn.Conv2d( + 
in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, 
batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): + config_class = Phi4MultimodalVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Phi4MultimodalVisionEmbeddings(config) + self.encoder = Phi4MultimodalVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._attn_implementation == "flash_attention_2" + else patch_attention_mask + ) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Phi4MultimodalImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = config.vision_config.image_size // config.vision_config.patch_size + if n_patches % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + n_patches += 1 + self.num_img_tokens = (n_patches // 2) ** 2 + + self.drop = nn.Dropout(config.embd_pdrop) + self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) + self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) + self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out])) + self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out])) + + def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: + img_processor_output = self.img_processor( + img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True + ) + img_feature = img_processor_output.hidden_states[self.layer_idx] + + patch_feature = img_feature + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + return patch_feature + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: torch.FloatTensor, + image_sizes: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_device = self.img_projection_up.bias.device + target_dtype = self.img_projection_up.bias.dtype + + 
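+        # Layout assumed by the slicing below: `image_pixel_values` is (batch, num_crops, C, H, W),
+        # where crop index 0 holds the global view of the image and crops 1..area_ratio hold the
+        # sub-crops tiling the full-resolution image.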
batch_size = image_pixel_values.shape[0]
+
+        # `image_attention_mask` is Optional (see the `None` branch below), so only flatten it
+        # when it is provided; `get_img_features` forwards `None` on to the vision model, which
+        # then builds a full patch mask itself.
+        if image_attention_mask is not None:
+            flat_attention_mask = image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device)
+        else:
+            flat_attention_mask = None
+        img_features = self.get_img_features(
+            image_pixel_values.flatten(0, 1),
+            attention_mask=flat_attention_mask,
+        )
+        base_feat_size = int(np.sqrt(img_features.shape[1]))
+        img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out)
+        image_sizes = image_sizes.view(-1, 2)
+
+        output_imgs = []
+        for idx in range(batch_size):
+            height, width = image_sizes[idx]
+            height_ratio = height // self.crop_size
+            width_ratio = width // self.crop_size
+            area_ratio = height_ratio * width_ratio
+
+            global_img = img_features[idx, :1]
+            global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous()
+            temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1)
+            global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)
+
+            sub_img = img_features[idx, 1:]
+            sub_img = sub_img[:area_ratio]
+            sub_img = (
+                sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out)
+                .transpose(1, 2)
+                .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out)
+                .contiguous()
+            )
+
+            if image_attention_mask is not None:
+                reshaped_image_attention_mask = (
+                    image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2]
+                    .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size)
+                    .transpose(1, 2)
+                    .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size)
+                )
+                useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
+                useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
+                sub_img = sub_img[:, :useful_height, :useful_width]
+                temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1)
+            else:
+                temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1)
+
+            sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)
+
+            # Merge global and sub
+            output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1))
+
+        img_set_tensor = []
+        for output_img in output_imgs:
+            output_img = output_img.to(device=target_device, dtype=target_dtype)
+            img_feature_proj = self.img_projection_up(output_img)
+            img_feature_proj = nn.functional.gelu(img_feature_proj)
+            img_feature_proj = self.img_projection_down(img_feature_proj)
+            img_set_tensor.append(img_feature_proj)
+
+        merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
+        merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+
+        with torch.no_grad():
+            positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True)
+
+        # Temporarily disable autocast to avoid issue on bf16 tensors
+        # Ref: https://github.com/pytorch/pytorch/issues/132715
+        with torch.autocast(device_type=inputs_embeds.device.type, enabled=False):
+            image_embeds = inputs_embeds.index_put(
+                indices=positions_tuple, values=merged_img_set_tensor, accumulate=False
+            )
+
+        image_embeds = self.drop(image_embeds)
+
+        return image_embeds
+
+
+########################################################## AUDIO #############################################
+
+
+class Phi4MultimodalAudioMLP(nn.Module):
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+        self.act_fn = ACT2FN[config.activation]
+
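+        # The fused projection below produces both GLU halves with a single matmul; `forward`
+        # chunks the result into `up_states` and `gate` and computes `up_states * act_fn(gate)`.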
self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + up_states = self.gate_up_proj(hidden_states) + up_states, gate = up_states.chunk(2, dim=-1) + up_states = up_states * self.act_fn(gate) + up_states = self.dropout(up_states) + hidden_states = self.down_proj(up_states) + out = self.dropout(hidden_states) + + return out + + +class Phi4MultimodalAudioAttention(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.dropout_rate + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): + super().__init__() + self.dw_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size * config.depthwise_multiplier, + config.kernel_size, + 1, + padding=padding, + groups=config.hidden_size, + ) + self.pw_conv = nn.Conv1d( + config.hidden_size * config.depthwise_multiplier, config.depthwise_seperable_out_channel, 1, 1, 0 + ) + + def forward(self, hidden_states): + return self.pw_conv(self.dw_conv(hidden_states)) + + +class Phi4MultimodalAudioGluPointWiseConv(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.output_dim = config.ext_pw_out_channel + + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) + self.glu_act = ACT2FN[config.conv_glu_type] + self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + + def forward(self, hidden_states): + # we assume the input always has the #channel (#dim) in the last dimension of the + # tensor, so need to switch the dimension first for 
1D-Conv case + hidden_states = hidden_states.permute([0, 2, 1]) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = hidden_states[:, 0 : self.output_dim, :] + self.b1 + out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) + return out.permute([0, 2, 1]) + + +class Phi4MultimodalAudioConvModule(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.kernel_size = config.kernel_size + + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.glu = Phi4MultimodalAudioGluPointWiseConv(config) + self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1) + self.act = ACT2FN[config.conv_activation] + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.glu(self.layer_norm(hidden_states)) + hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) + + if self.kernel_size > 1: + hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] + + hidden_states = self.act(hidden_states) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = self.dropout(hidden_states.permute([0, 2, 1])) + return out + + +class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.feed_forward_in = Phi4MultimodalAudioMLP(config) + self.self_attn = Phi4MultimodalAudioAttention(config) + self.conv = Phi4MultimodalAudioConvModule(config) + self.feed_forward_out = Phi4MultimodalAudioMLP(config) + self.layer_norm_att = nn.LayerNorm(config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ): + residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) + hidden_states = self.layer_norm_att(residual) + + hidden_states = residual + self.self_attn(hidden_states, attention_mask) + hidden_states = hidden_states + self.conv(hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) + + out = self.layer_norm(hidden_states) + + return out + + +class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.subsampling_factor = config.time_reduction + self.sampling_num = int(math.log(self.subsampling_factor, 2)) + self.act_fn = ACT2FN[config.nemo_activation] + conv_channels = config.nemo_conv_channels + + layers = [ + nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), + self.act_fn, + ] + for _ in range(self.sampling_num - 1): + layers.extend( + [ + nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), + nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), + self.act_fn, + ] + ) + + # Aggregate the layers + self.conv = torch.nn.Sequential(*layers) + self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): + # Unsqueeze Channel Axis + hidden_states = hidden_states.unsqueeze(1) + hidden_states = self.conv(hidden_states) + + # Flatten Channel and Frequency Axes + b, _, t, _ = hidden_states.size() + hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) + 
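+        # The rest of the method rebuilds the padding mask at the subsampled rate: each output
+        # frame now covers `subsampling_factor` input frames, so the per-example valid lengths
+        # shrink by that factor (rounded up).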
+ if mask is None: + return hidden_states, None + + max_audio_length = hidden_states.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + arange_ = torch.arange(0, max_audio_length, device=hidden_states.device) + pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) + return hidden_states, pad_mask.unsqueeze(1) + + +class Phi4MultimodalAudioRelativeAttentionBias(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.max_distance = config.bias_max_distance + self.symmetric = config.bias_symmetric + self.num_buckets = self.max_distance + if not config.bias_symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + max_pos = x.size(1) + context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None] + memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + bias_idx = relative_position + bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2 + + att_bias = self.bias_values(bias_idx) + att_bias = att_bias.permute(2, 0, 1).unsqueeze(0) + + return att_bias + + +class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.register_buffer("global_mean", torch.zeros(config.input_size)) + self.register_buffer("global_invstd", torch.ones(config.input_size)) + + def forward(self, x): + return (x - self.global_mean) * self.global_invstd + + +class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): + config_class = Phi4MultimodalAudioConfig + supports_gradient_checkpointing = True + _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv): + module.b1.data.zero_() + module.b2.data.zero_() + + +def unfold_tensor(tensor, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, + this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. 
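+    For example (shapes illustrative), a (2, 1000, 80) input with max_seq_len=500 becomes
+    (4, 500, 80). T is expected to already be a multiple of max_seq_len, since F.unfold
+    silently drops residual frames; the caller pads beforehand.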
+    Args:
+        tensor: N, T, D
+    """
+    _, _, D = tensor.shape
+    tensor = tensor.transpose(-1, -2)
+    # N x D x 1 x T => N x (D x max_seq_len) x T'
+    tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len))
+
+    new_bsz, _, slen = tensor.shape
+    tensor = tensor.view(new_bsz, -1, max_seq_len, slen)
+    tensor = tensor.permute(0, 3, 2, 1)
+    tensor = tensor.view(-1, max_seq_len, D).contiguous()
+    return tensor
+
+
+def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+    """
+    Create the chunk-based attention mask used by the Transformer Transducer streaming mode.
+    Args:
+        x_len (int): sequence length
+        chunk_start_idx (list): first idx of each chunk, such as [0, 18, 36, 48]. It also supports adaptive chunk sizes such as [0, 10, 15, 45]
+        left_window (int): how many left chunks can be seen
+        right_window (int): how many right chunks can be seen. It is used for the chunk overlap model.
+    Returns:
+        mask (torch.Tensor): a mask tensor for the streaming model
+    """
+    chunk_start_idx = torch.Tensor(chunk_start_idx).long()
+    start_pad = torch.nn.functional.pad(
+        chunk_start_idx, (1, 0)
+    )  # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48]
+    end_pad = torch.nn.functional.pad(
+        chunk_start_idx, (0, 1), value=x_len
+    )  # append x_len to the end, so it becomes [0, 18, 36, 48, x_len]
+    seq_range = torch.arange(0, x_len).unsqueeze(-1)
+    idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
+    seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
+    idx_left = idx - left_window
+    idx_left[idx_left < 0] = 0
+    boundary_left = start_pad[idx_left]
+    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+    idx_right = idx + right_window
+    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+    boundary_right = end_pad[idx_right]
+    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+    return mask_left & mask_right
+
+
+class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config)
+        self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
+        self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config)
+        self.encoders = nn.ModuleList(
+            [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)]
+        )
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
+        # Create the mask matrix for streaming
+        # S stores the chunk start indices: if chunk_size is 18, S is [0, 18, 36, ...]
+        chunk_start_idx = np.arange(0, seq_len, chunk_size)
+        # avoid randomness when running evaluation or decoding
+        if self.training and np.random.rand() > 0.5:
+            # Either the first or the last chunk is not complete.
+            # If only the last one is not complete, EOS is not effective
+            chunk_start_idx = seq_len - chunk_start_idx
+            chunk_start_idx = chunk_start_idx[::-1]
+            chunk_start_idx = chunk_start_idx[:-1]
+            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
+
+        enc_streaming_mask = (
+            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
+            .unsqueeze(0)
+            .expand([batch_size, -1, -1])
+        )
+        return enc_streaming_mask
+
+    def forward_embeddings(self, hidden_states, masks):
+        """Forward the inputs through the top embedding layers."""
+        seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
+        if seq_len <= 0:
+            raise ValueError(
+                f"The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short."
+            )
+
+        batch_size = hidden_states.shape[0]
+
+        enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk)
+        enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)
+
+        hidden_states, masks = self.embed(hidden_states, masks)
+
+        streaming_mask = enc_streaming_mask
+        if streaming_mask is not None and masks is not None:
+            hs_mask = masks & streaming_mask
+        elif masks is not None:
+            hs_mask = masks
+        else:
+            hs_mask = streaming_mask
+
+        return hidden_states, hs_mask, masks
+
+    def calculate_hs_mask(self, hidden_states, device, mask):
+        max_audio_length = hidden_states.shape[1]
+        batch_size = hidden_states.shape[0]
+        enc_streaming_mask = self._streaming_mask(
+            max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk
+        )
+        enc_streaming_mask = enc_streaming_mask.to(device)
+        if mask is None:
+            return enc_streaming_mask
+
+        feature_lens = mask.sum(1)
+        padding_length = feature_lens
+        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
+            padding_length.size(0), -1
+        ) < padding_length.unsqueeze(1)
+        pad_mask = pad_mask.unsqueeze(1)
+        pad_mask = pad_mask & enc_streaming_mask
+        return pad_mask
+
+    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]):
+        hidden_states = self.encoder_embedding(hidden_states)
+        hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask)
+
+        unfolded = False
+        bs, seq_len, _ = hidden_states.shape
+        max_seq_len = 500  # maximum position for absolute positional encoding
+        if seq_len > max_seq_len:
+            # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len
+            unfolded = True
+            # the unfold op will drop residual frames, pad it to a multiple of max_seq_len
+            if seq_len % max_seq_len > 0:
+                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
+            else:
+                chunk_pad_size = 0
+            if chunk_pad_size > 0:
+                hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0)
+                hidden_states = hidden_states_pad.to(hidden_states.device)
+
+            hidden_states = unfold_tensor(hidden_states, max_seq_len)
+            masks_unfold = None
+            if mask is not None:
+                # revise hs_mask here because the previously calculated hs_mask did not consider extra pad
+                subsampled_pad_mask = mask.squeeze(1)  # [bz, subsampled_unmask_seq_len]
+                extra_padded_subsampled_pad_mask = F.pad(
+                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
+                )  # extra padding to the pad mask
+                extra_padded_subsampled_pad_mask = extra_padded_subsampled_pad_mask.unsqueeze(-1).float()
+                masks_unfold = unfold_tensor(
+                    extra_padded_subsampled_pad_mask, max_seq_len
+                )  # unfold the pad mask like we did to the input tensor
+                masks_unfold = masks_unfold.squeeze(-1).bool()  # unfold op does not support bool tensor
+            hs_mask = 
self.calculate_hs_mask( + hidden_states, hidden_states.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + relative_attention_bias = self.relative_attention_bias_layer(hidden_states) + attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias + + for layer in self.encoders: + hidden_states = layer(hidden_states, attention_mask) + + if unfolded: + embed_dim = hidden_states.shape[-1] + hidden_states = hidden_states.reshape(bs, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + hidden_states = hidden_states[:, :-chunk_pad_size, :] + + return hidden_states + + +class Phi4MultimodalAudioEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.audio_config.feature_layer + + self.drop = nn.Dropout(config.embd_pdrop) + self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) + self.up_proj_for_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) + self.up_proj_for_vision_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + audio_input_features: torch.FloatTensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + ) -> torch.FloatTensor: + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) + + up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech + down_proj = ( + self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech + ) + + target_device = up_proj.bias.device + target_dtype = up_proj.bias.dtype + + audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype) + + audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) + audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) + audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states) + audio_embeds = down_proj(audio_encoder_hidden_states) + + merged_audio_embeds = torch.cat( + [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 + ) + merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + audio_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_audio_embeds, accumulate=False + ) + + audio_embeds = self.drop(audio_embeds) + + return audio_embeds + + +@use_kernel_forward_from_hub("RMSNorm") +class Phi4MultimodalRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi4MultimodalRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = 
hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi4MultimodalMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi4MultimodalAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi4MultimodalConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx) + self.mlp = Phi4MultimodalMLP(config) + self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Phi4MultimodalFeatureEmbedding(nn.Module): + """Image-audio embedding.""" + + def __init__(self, config: Phi4MultimodalConfig) -> None: + super().__init__() + self.config = config + self.image_token_id = config.vision_config.image_token_id + self.audio_token_id = config.audio_config.audio_token_id + self.image_embed = Phi4MultimodalImageEmbedding(config) + self.audio_embed = Phi4MultimodalAudioEmbedding(config) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: Optional[torch.FloatTensor] = None, + audio_input_features: Optional[torch.FloatTensor] = None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + ) -> torch.FloatTensor: + with torch.no_grad(): + image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1) + non_image_position_mask = ~image_position_mask + + image_embeds = None + audio_embeds = None + if image_pixel_values is not None and (input_ids == self.image_token_id).any(): + image_embeds = self.image_embed( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + ) + if audio_input_features is not None and (input_ids == self.audio_token_id).any(): + audio_projection_mode = "vision" if image_pixel_values is not None else "speech" + audio_embeds = self.audio_embed( + input_ids, + inputs_embeds, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + + # merge image and audio + if image_embeds is not None and audio_embeds is not None: + inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask + elif image_embeds is not None: + inputs_embeds = image_embeds + elif audio_embeds is not None: + inputs_embeds = audio_embeds + + return inputs_embeds + + +class Phi4MultimodalRotaryEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = 
config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +PHI4_MULTIMODAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi4MultimodalConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.", + PHI4_MULTIMODAL_START_DOCSTRING, +) +class Phi4MultimodalPreTrainedModel(PreTrainedModel): + config_class = Phi4MultimodalConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi4MultimodalDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Phi4MultimodalRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Phi4MultimodalImageEmbedding): + module.global_img_feature_extensor.data.zero_() + module.sub_img_feature_extensor.data.zero_() + + +PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. 
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            See our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        image_pixel_values (`torch.FloatTensor`, *optional*):
+            If the input contains images, these correspond to the pixel values after transformations (as returned by
+            the Processor).
+        image_sizes (`torch.LongTensor`, *optional*):
+            If the input contains images, these correspond to the size of each image.
+        image_attention_mask (`torch.LongTensor`, *optional*):
+            Attention mask for the images.
+        audio_input_features (`torch.FloatTensor`, *optional*):
+            If the input contains audio samples, these correspond to the values after transformation (as returned by
+            the Processor).
+        audio_embed_sizes (`torch.Tensor`, *optional*):
+            Size of the audio inputs.
+        audio_attention_mask (`torch.Tensor`, *optional*):
+            Attention mask for the audio inputs.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
+    PHI4_MULTIMODAL_START_DOCSTRING,
+)
+class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalDecoderLayer`].
+
+    Args:
+        config: Phi4MultimodalConfig
+    """
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
+        self.layers = nn.ModuleList(
+            [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config)
+
+        self.gradient_checkpointing = False
+        self.embed_dropout = nn.Dropout(config.embd_pdrop)
+
+        self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_pixel_values: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+        image_attention_mask=None,
+        audio_input_features: Optional[torch.FloatTensor] = None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> BaseModelOutputWithPast:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens_extend( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + audio_input_features=audio_input_features, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
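+        # The branches below decide, per attention backend, whether an explicit mask is needed at
+        # all: with SDPA, a dynamic (or no) cache and no `output_attentions`, returning `None`
+        # lets PyTorch dispatch to its fused causal kernels. Otherwise a 4D float mask of shape
+        # `(batch_size, 1, query_length, key_value_length)` is built, in which allowed positions
+        # hold 0.0 and masked positions hold the dtype minimum (`torch.finfo(dtype).min`), so the
+        # mask can simply be added to the attention scores.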
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Phi4MultimodalConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Phi4MultimodalConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.get_text_config().sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.get_text_config().sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Phi4MultimodalModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + audio_input_features: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + labels: Optional[torch.LongTensor] = None, + 
use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + + Example: + ```python + >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM + >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA") + >>> tokenizer = AutoTokenizer.from_pretrained("TBA") + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + image_pixel_values=None, + image_sizes=None, + image_attention_mask=None, + audio_input_features=None, + audio_embed_sizes=None, + audio_attention_mask=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=0, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. 
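+        # Concretely, with the default `original_max_position_embeddings=4096` and a `longrope`
+        # scaling config: the first time the total sequence reaches 4097 tokens, the rotary
+        # embeddings switch from the `short_factor` to the `long_factor` scaling, so keys/values
+        # cached under the short factors no longer match and the cache is dropped (and recomputed)
+        # exactly once below.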
+ if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +__all__ = [ + "Phi4MultimodalAudioPreTrainedModel", + "Phi4MultimodalAudioModel", + "Phi4MultimodalVisionPreTrainedModel", + "Phi4MultimodalVisionModel", + "Phi4MultimodalPreTrainedModel", + "Phi4MultimodalModel", + "Phi4MultimodalForCausalLM", +] diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/docs/transformers/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d6bc0daf9434049629a6a1cf7ba48177180b45 --- /dev/null +++ b/docs/transformers/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -0,0 +1,1850 @@ +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
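+# Note: this is the "modular" source for the model. The transformers modular converter expands the
+# Phi3/Siglip inheritance below into the standalone `modeling_phi4_multimodal.py` shown above, so
+# edits belong here rather than in the generated file.
+#
+# A minimal usage sketch, kept as a comment. The checkpoint name matches the docstrings below; the
+# `<|user|>...<|image_1|>...<|end|><|assistant|>` prompt format and "example.jpg" are illustrative
+# assumptions, not part of this file:
+#
+#     from PIL import Image
+#     from transformers import AutoProcessor, Phi4MultimodalForCausalLM
+#
+#     processor = AutoProcessor.from_pretrained("microsoft/Phi-4-multimodal-instruct")
+#     model = Phi4MultimodalForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct")
+#     prompt = "<|user|><|image_1|>Describe this image.<|end|><|assistant|>"
+#     inputs = processor(text=prompt, images=Image.open("example.jpg"), return_tensors="pt")
+#     generated = model.generate(**inputs, max_new_tokens=64)
+#     print(processor.batch_decode(generated, skip_special_tokens=True)[0])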
+ +import math +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from ...activations import ACT2FN +from ...cache_utils import DynamicCache +from ...configuration_utils import PretrainedConfig +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from ..phi3.configuration_phi3 import Phi3Config +from ..phi3.modeling_phi3 import ( + Phi3DecoderLayer, + Phi3ForCausalLM, + Phi3Model, + Phi3PreTrainedModel, + Phi3RMSNorm, + Phi3RotaryEmbedding, +) +from ..siglip.configuration_siglip import SiglipVisionConfig +from ..siglip.modeling_siglip import ( + SiglipEncoder, + SiglipEncoderLayer, + SiglipMLP, + SiglipMultiheadAttentionPoolingHead, + SiglipPreTrainedModel, + SiglipVisionEmbeddings, + default_flax_embed_init, + lecun_normal_, +) + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalVisionConfig(SiglipVisionConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalVisionModel`]. It is used to instantiate a + Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 4304): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 27): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 448): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + crop_size (`int`, *optional*, defaults to 448): + Crop size for the input images. + image_token_id (`int`, *optional*, defaults to 200010): + The image token id. 
+ feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract image features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalVisionConfig + + >>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalVisionConfig() + ```""" + + model_type = "phi4_multimodal_vision" + + def __init__( + self, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=448, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + crop_size: int = 448, + image_token_id: int = 200010, + feature_layer: int = -2, + **kwargs, + ): + super().__init__( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_dropout=attention_dropout, + **kwargs, + ) + self.crop_size = crop_size + self.image_token_id = image_token_id + self.feature_layer = feature_layer + + +class Phi4MultimodalAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalAudioModel`]. It is used to instantiate a + Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers. + intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_blocks (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the MLPs. + chunk_size (`int`, *optional*, defaults to -1): + The chunk size to create the masks. + left_chunk (`int`, *optional*, defaults to 18): + The left chunk to create the masks. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio. + ext_pw_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the point-wise conv modules. + depthwise_seperable_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the depth-wise separable conv modules. + depthwise_multiplier (`int`, *optional*, defaults to 1): + Input size multiplier for the depth-wise separable conv modules. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the depth-wise separable conv modules. + conv_activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the conv modules. + input_size (`int`, *optional*, defaults to 80): + Input size for the audio model. 
+ conv_glu_type (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the point-wise conv modules. + time_reduction (`int`, *optional*, defaults to 8): + Time reduction (subsampling factor). + bias_max_distance (`int`, *optional*, defaults to 1000): + Max distance for the relative attention bias module. + bias_symmetric (`bool`, *optional*, defaults to `False`): + Whether the relative attention bias should be symmetric or not. + nemo_activation (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the nemo conv modules. + nemo_conv_channels (`int`, *optional*, defaults to 1024): + Number of channels in the nemo conv modules. + downsample_rate (`int`, *optional*, defaults to 1): + Downsample rate for the audio feature extractor. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + audio_token_id (`int`, *optional*, defaults to 200011): + The audio token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract audio features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalAudioConfig + + >>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalAudioConfig() + ```""" + + model_type = "phi4_multimodal_audio" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 1536, + num_blocks: int = 24, + num_attention_heads: int = 16, + activation: str = "swish", + chunk_size: int = -1, + left_chunk: int = 18, + dropout_rate: float = 0.0, + ext_pw_out_channel: int = 1024, + depthwise_seperable_out_channel: int = 1024, + depthwise_multiplier: int = 1, + kernel_size: int = 3, + conv_activation: str = "swish", + input_size: int = 80, + conv_glu_type: str = "swish", + time_reduction: int = 8, + bias_max_distance: int = 1000, + bias_symmetric: bool = False, + nemo_activation: str = "relu", + nemo_conv_channels: int = 1024, + downsample_rate: int = 1, + initializer_range: float = 0.02, + audio_token_id: int = 200011, + feature_layer: int = -2, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.activation = activation + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.num_blocks = num_blocks + self.dropout_rate = dropout_rate + self.ext_pw_out_channel = ext_pw_out_channel + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.depthwise_multiplier = depthwise_multiplier + self.kernel_size = kernel_size + self.conv_activation = conv_activation + self.input_size = input_size + self.conv_glu_type = conv_glu_type + self.time_reduction = time_reduction + self.bias_max_distance = bias_max_distance + self.bias_symmetric = bias_symmetric + self.nemo_activation = nemo_activation + self.nemo_conv_channels = nemo_conv_channels + self.downsample_rate = downsample_rate + self.audio_token_id = audio_token_id + self.initializer_range = initializer_range + self.feature_layer = feature_layer + + if time_reduction % 2 != 0: + raise ValueError("`time_reduction` should be a multiple of 2!") + length = input_size + for _ in range(int(math.log(time_reduction, 2))): + length = math.floor((length - 1) / 2 + 1) + self.nemo_final_size = length + + +class Phi4MultimodalConfig(Phi3Config): + r""" + This is the 
configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a + Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 200064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. 
The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + partial_rotary_factor (`float`, *optional*, defaults to `1.0`): + Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0. + bos_token_id (`int`, *optional*, defaults to 199999): + The id of the "beginning-of-sequence" token. + eos_token_id (`int` or `list[int]`, *optional*, defaults to `[199999, 200020]`): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 199999): + The id of the padding token. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + vision_config (`Phi4MultimodalVisionConfig` or `dict`, *optional*): + The vision config for the underlying image embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + audio_config (`Phi4MultimodalAudioConfig` or `dict`, *optional*): + The audio config for the underlying audio embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + + Example: + + ```python + >>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig + + >>> # Initializing a Phi4Multimodal style configuration + >>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi4MultimodalModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig} + + def __init__( + self, + vocab_size=200064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1, + bos_token_id=199999, + eos_token_id=[199999, 200020], + pad_token_id=199999, + original_max_position_embeddings=4096, + sliding_window=None, + vision_config=None, + audio_config=None, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attention_dropout=attention_dropout, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + 
bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            original_max_position_embeddings=original_max_position_embeddings,
+            sliding_window=sliding_window,
+            **kwargs,
+        )
+
+        if isinstance(vision_config, dict):
+            vision_config = Phi4MultimodalVisionConfig(**vision_config)
+        elif vision_config is None:
+            vision_config = Phi4MultimodalVisionConfig()
+        self.vision_config = vision_config
+
+        if isinstance(audio_config, dict):
+            audio_config = Phi4MultimodalAudioConfig(**audio_config)
+        elif audio_config is None:
+            audio_config = Phi4MultimodalAudioConfig()
+        self.audio_config = audio_config
+
+
+class Phi4MultimodalVisionMLP(SiglipMLP):
+    pass
+
+
+def simple_eager_attention_forward(
+    module: nn.Module,
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class Phi4MultimodalVisionAttention(nn.Module):
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.scaling = self.head_dim**-0.5
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Input shape: Batch x Time x Channel"""
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = simple_eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Phi4MultimodalVisionEncoderLayer(SiglipEncoderLayer):
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__(config)
+        self.self_attn = Phi4MultimodalVisionAttention(config)
+        self.mlp = Phi4MultimodalVisionMLP(config)
+
+
+class Phi4MultimodalVisionEncoder(SiglipEncoder):
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+        )
+
+
+class Phi4MultimodalVisionPreTrainedModel(SiglipPreTrainedModel):
+    config_class = Phi4MultimodalVisionConfig
+    base_model_prefix = "phi4_vision"
+    supports_gradient_checkpointing = True
+
+    _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, Phi4MultimodalVisionEmbeddings):
+            width = self.config.hidden_size
+            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+        elif isinstance(module, nn.Embedding):
+            default_flax_embed_init(module.weight)
+        elif isinstance(module, Phi4MultimodalVisionAttention):
+            nn.init.normal_(module.q_proj.weight)
+            nn.init.normal_(module.k_proj.weight)
+            nn.init.normal_(module.v_proj.weight)
+            nn.init.normal_(module.out_proj.weight)
+            nn.init.zeros_(module.q_proj.bias)
+            nn.init.zeros_(module.k_proj.bias)
+            nn.init.zeros_(module.v_proj.bias)
+            nn.init.zeros_(module.out_proj.bias)
+        elif isinstance(module, Phi4MultimodalVisionMLP):
+            nn.init.normal_(module.fc1.weight)
+            nn.init.normal_(module.fc2.weight)
+            nn.init.normal_(module.fc1.bias, std=1e-6)
+            nn.init.normal_(module.fc2.bias, std=1e-6)
+        elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead):
+            nn.init.normal_(module.probe.data)
+            nn.init.normal_(module.attention.in_proj_weight.data)
+            nn.init.zeros_(module.attention.in_proj_bias.data)
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            lecun_normal_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class Phi4MultimodalVisionEmbeddings(SiglipVisionEmbeddings, nn.Module):
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        nn.Module.__init__(self)
+        self.config = config
+        self.patch_size = config.patch_size
+        self.num_patches_per_side = config.image_size // self.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+        self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size)
+
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)
+
+        patch_embeds = self.patch_embedding(pixel_values)
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = 
torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Phi4MultimodalVisionMultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): + config_class = Phi4MultimodalVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Phi4MultimodalVisionEmbeddings(config) + self.encoder = Phi4MultimodalVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._attn_implementation == "flash_attention_2" + else patch_attention_mask + ) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Phi4MultimodalImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = config.vision_config.image_size // config.vision_config.patch_size + if n_patches % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + n_patches += 1 + self.num_img_tokens = (n_patches // 2) ** 2 + + self.drop = nn.Dropout(config.embd_pdrop) + self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) + self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) + self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out])) + self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out])) + + def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: + img_processor_output = self.img_processor( + img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True + ) + img_feature = img_processor_output.hidden_states[self.layer_idx] + + patch_feature = img_feature + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + return patch_feature + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: torch.FloatTensor, + image_sizes: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_device = self.img_projection_up.bias.device + target_dtype = self.img_projection_up.bias.dtype + + 
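+        # `img_features` (computed below) is laid out per image as [global crop, sub-crop 0,
+        # sub-crop 1, ...]: index 0 along dim 1 is the whole-image (global) crop, and the next
+        # `height_ratio * width_ratio` entries are crop_size x crop_size tiles of the image;
+        # each crop contributes `base_feat_size**2` tokens after the 2x2 average-pool compression.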
batch_size = image_pixel_values.shape[0] + + img_features = self.get_img_features( + image_pixel_values.flatten(0, 1), + attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device), + ) + base_feat_size = int(np.sqrt(img_features.shape[1])) + img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out) + image_sizes = image_sizes.view(-1, 2) + + output_imgs = [] + for idx in range(batch_size): + height, width = image_sizes[idx] + height_ratio = height // self.crop_size + width_ratio = width // self.crop_size + area_ratio = height_ratio * width_ratio + + global_img = img_features[idx, :1] + global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous() + temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1) + global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + sub_img = img_features[idx, 1:] + sub_img = sub_img[:area_ratio] + sub_img = ( + sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out) + .contiguous() + ) + + if image_attention_mask is not None: + reshaped_image_attention_mask = ( + image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] + .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1) + else: + temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1) + + sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + # Merge global and sub + output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1)) + + img_set_tensor = [] + for output_img in output_imgs: + output_img = output_img.to(device=target_device, dtype=target_dtype) + img_feature_proj = self.img_projection_up(output_img) + img_feature_proj = nn.functional.gelu(img_feature_proj) + img_feature_proj = self.img_projection_down(img_feature_proj) + img_set_tensor.append(img_feature_proj) + + merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) + merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True) + + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + image_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_img_set_tensor, accumulate=False + ) + + image_embeds = self.drop(image_embeds) + + return image_embeds + + +########################################################## AUDIO ############################################# + + +class Phi4MultimodalAudioMLP(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.act_fn = ACT2FN[config.activation] + 
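+        # GLU-style feed-forward: one fused Linear emits 2 * intermediate_size features that
+        # forward() chunks into `up` and `gate` halves and combines as `up * act_fn(gate)`
+        # before the down projection, with dropout applied after both the gate and the output.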
self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + up_states = self.gate_up_proj(hidden_states) + up_states, gate = up_states.chunk(2, dim=-1) + up_states = up_states * self.act_fn(gate) + up_states = self.dropout(up_states) + hidden_states = self.down_proj(up_states) + out = self.dropout(hidden_states) + + return out + + +class Phi4MultimodalAudioAttention(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.dropout_rate + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): + super().__init__() + self.dw_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size * config.depthwise_multiplier, + config.kernel_size, + 1, + padding=padding, + groups=config.hidden_size, + ) + self.pw_conv = nn.Conv1d( + config.hidden_size * config.depthwise_multiplier, config.depthwise_seperable_out_channel, 1, 1, 0 + ) + + def forward(self, hidden_states): + return self.pw_conv(self.dw_conv(hidden_states)) + + +class Phi4MultimodalAudioGluPointWiseConv(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.output_dim = config.ext_pw_out_channel + + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) + self.glu_act = ACT2FN[config.conv_glu_type] + self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + + def forward(self, hidden_states): + # we assume the input always has the #channel (#dim) in the last dimension of the + # tensor, so need to switch the dimension first for 
1D-Conv case + hidden_states = hidden_states.permute([0, 2, 1]) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = hidden_states[:, 0 : self.output_dim, :] + self.b1 + out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) + return out.permute([0, 2, 1]) + + +class Phi4MultimodalAudioConvModule(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.kernel_size = config.kernel_size + + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.glu = Phi4MultimodalAudioGluPointWiseConv(config) + self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1) + self.act = ACT2FN[config.conv_activation] + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.glu(self.layer_norm(hidden_states)) + hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) + + if self.kernel_size > 1: + hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] + + hidden_states = self.act(hidden_states) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = self.dropout(hidden_states.permute([0, 2, 1])) + return out + + +class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.feed_forward_in = Phi4MultimodalAudioMLP(config) + self.self_attn = Phi4MultimodalAudioAttention(config) + self.conv = Phi4MultimodalAudioConvModule(config) + self.feed_forward_out = Phi4MultimodalAudioMLP(config) + self.layer_norm_att = nn.LayerNorm(config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ): + residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) + hidden_states = self.layer_norm_att(residual) + + hidden_states = residual + self.self_attn(hidden_states, attention_mask) + hidden_states = hidden_states + self.conv(hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) + + out = self.layer_norm(hidden_states) + + return out + + +class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.subsampling_factor = config.time_reduction + self.sampling_num = int(math.log(self.subsampling_factor, 2)) + self.act_fn = ACT2FN[config.nemo_activation] + conv_channels = config.nemo_conv_channels + + layers = [ + nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), + self.act_fn, + ] + for _ in range(self.sampling_num - 1): + layers.extend( + [ + nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), + nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), + self.act_fn, + ] + ) + + # Aggregate the layers + self.conv = torch.nn.Sequential(*layers) + self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): + # Unsqueeze Channel Axis + hidden_states = hidden_states.unsqueeze(1) + hidden_states = self.conv(hidden_states) + + # Flatten Channel and Frequency Axes + b, _, t, _ = hidden_states.size() + hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) + 
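+        # (Editor's note) Each stride-2 Conv2d above halves the time axis, so the sampling_num
+        # convolutions reduce T by subsampling_factor = 2 ** sampling_num. As an illustrative
+        # example under an assumed config (not a real checkpoint): time_reduction=8 gives 3
+        # stride-2 convolutions, so 3000 input frames come out as ceil(3000 / 8) = 375 positions,
+        # which is also how the pad mask below is recomputed from `feature_lens`.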
+ if mask is None: + return hidden_states, None + + max_audio_length = hidden_states.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + arange_ = torch.arange(0, max_audio_length, device=hidden_states.device) + pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) + return hidden_states, pad_mask.unsqueeze(1) + + +class Phi4MultimodalAudioRelativeAttentionBias(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.max_distance = config.bias_max_distance + self.symmetric = config.bias_symmetric + self.num_buckets = self.max_distance + if not config.bias_symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + max_pos = x.size(1) + context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None] + memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + bias_idx = relative_position + bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2 + + att_bias = self.bias_values(bias_idx) + att_bias = att_bias.permute(2, 0, 1).unsqueeze(0) + + return att_bias + + +class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.register_buffer("global_mean", torch.zeros(config.input_size)) + self.register_buffer("global_invstd", torch.ones(config.input_size)) + + def forward(self, x): + return (x - self.global_mean) * self.global_invstd + + +class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): + config_class = Phi4MultimodalAudioConfig + supports_gradient_checkpointing = True + _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv): + module.b1.data.zero_() + module.b2.data.zero_() + + +class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__(config) + self.config = config + + self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config) + self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) + self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config) + self.encoders = nn.ModuleList( + [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in 
range(config.num_blocks)]
+        )
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
+        # Create mask matrix for streaming
+        # chunk_start_idx stores the start index of each chunk: if chunk_size is 18, it is [0, 18, 36, ...]
+        chunk_start_idx = np.arange(0, seq_len, chunk_size)
+        # avoid randomness when running evaluation or decoding
+        if self.training and np.random.rand() > 0.5:
+            # Either the first or the last chunk is not complete.
+            # If only the last one is not complete, EOS is not effective
+            chunk_start_idx = seq_len - chunk_start_idx
+            chunk_start_idx = chunk_start_idx[::-1]
+            chunk_start_idx = chunk_start_idx[:-1]
+            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
+
+        enc_streaming_mask = (
+            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
+            .unsqueeze(0)
+            .expand([batch_size, -1, -1])
+        )
+        return enc_streaming_mask
+
+    def forward_embeddings(self, hidden_states, masks):
+        """Forwarding the inputs through the top embedding layers"""
+        seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
+        if seq_len <= 0:
+            raise ValueError(
+                f"The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short."
+            )
+
+        batch_size = hidden_states.shape[0]
+
+        enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk)
+        enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)
+
+        hidden_states, masks = self.embed(hidden_states, masks)
+
+        streaming_mask = enc_streaming_mask
+        if streaming_mask is not None and masks is not None:
+            hs_mask = masks & streaming_mask
+        elif masks is not None:
+            hs_mask = masks
+        else:
+            hs_mask = streaming_mask
+
+        return hidden_states, hs_mask, masks
+
+    def calculate_hs_mask(self, hidden_states, device, mask):
+        max_audio_length = hidden_states.shape[1]
+        batch_size = hidden_states.shape[0]
+        enc_streaming_mask = self._streaming_mask(
+            max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk
+        )
+        enc_streaming_mask = enc_streaming_mask.to(device)
+        if mask is None:
+            return enc_streaming_mask
+
+        feature_lens = mask.sum(1)
+        padding_length = feature_lens
+        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
+            padding_length.size(0), -1
+        ) < padding_length.unsqueeze(1)
+        pad_mask = pad_mask.unsqueeze(1)
+        pad_mask = pad_mask & enc_streaming_mask
+        return pad_mask
+
+    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]):
+        hidden_states = self.encoder_embedding(hidden_states)
+        hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask)
+
+        unfolded = False
+        bs, seq_len, _ = hidden_states.shape
+        max_seq_len = 500  # maximum position for absolute positional encoding
+        if seq_len > max_seq_len:
+            # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len
+            unfolded = True
+            # the unfold op will drop residual frames, so pad to a multiple of max_seq_len
+            if seq_len % max_seq_len > 0:
+                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
+            else:
+                chunk_pad_size = 0
+            if chunk_pad_size > 0:
+                hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0)
+                hidden_states = hidden_states_pad.to(hidden_states.device)
+
+            hidden_states = unfold_tensor(hidden_states, max_seq_len)
+            masks_unfold = None
+            if mask is not None:
+                # revise hs_mask here because the previously calculated hs_mask did not consider the extra padding
+                subsampled_pad_mask = mask.squeeze(1)  # [bz, subsampled_unmask_seq_len]
+                extra_padded_subsamlped_pad_mask = F.pad(
+                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
+                )  # extra padding to the pad mask
+                extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
+                masks_unfold = unfold_tensor(
+                    extra_padded_subsamlped_pad_mask, max_seq_len
+                )  # unfold the pad mask like we did to the input tensor
+                masks_unfold = masks_unfold.squeeze(-1).bool()  # unfold op does not support bool tensor
+            hs_mask = self.calculate_hs_mask(
+                hidden_states, hidden_states.device, masks_unfold
+            )  # calculate hs_mask based on the unfolded pad mask
+
+        relative_attention_bias = self.relative_attention_bias_layer(hidden_states)
+        attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
+
+        for layer in self.encoders:
+            hidden_states = layer(hidden_states, attention_mask)
+
+        if unfolded:
+            embed_dim = hidden_states.shape[-1]
+            hidden_states = hidden_states.reshape(bs, -1, embed_dim)
+            # if we ever padded before unfolding, we need to remove the padding
+            if chunk_pad_size > 0:
+                hidden_states = hidden_states[:, :-chunk_pad_size, :]
+
+        return hidden_states
+
+
+def unfold_tensor(tensor, max_seq_len):
+    """
+    For a given tensor with shape (N, T, D), if the sequence length T is longer than max_seq_len,
+    this function unfolds it to a (N * T', max_seq_len, D) tensor, where T' is T // max_seq_len.
+    Args:
+        tensor: N, T, D
+    """
+    _, _, D = tensor.shape
+    tensor = tensor.transpose(-1, -2)
+    # N x D x 1 x T => N x (D x max_seq_len) x T'
+    tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len))
+
+    new_bsz, _, slen = tensor.shape
+    tensor = tensor.view(new_bsz, -1, max_seq_len, slen)
+    tensor = tensor.permute(0, 3, 2, 1)
+    tensor = tensor.view(-1, max_seq_len, D).contiguous()
+    return tensor
+
+
+def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+    """
+    This function is essential for the Transformer-Transducer streaming mode.
+    Args:
+        x_len (int): sequence length
+        chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk sizes [0,10,15,45]
+        left_window (int): how many left chunks can be seen
+        right_window (int): how many right chunks can be seen. It is used for the chunk-overlap model.
+ Returns: + mask (torch.Tensor): a mask tensor for streaming model + """ + chunk_start_idx = torch.Tensor(chunk_start_idx).long() + start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Phi4MultimodalAudioEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.audio_config.feature_layer + + self.drop = nn.Dropout(config.embd_pdrop) + self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) + self.up_proj_for_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) + self.up_proj_for_vision_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + audio_input_features: torch.FloatTensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + ) -> torch.FloatTensor: + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) + + up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech + down_proj = ( + self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech + ) + + target_device = up_proj.bias.device + target_dtype = up_proj.bias.dtype + + audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype) + + audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) + audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) + audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states) + audio_embeds = down_proj(audio_encoder_hidden_states) + + merged_audio_embeds = torch.cat( + [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 + ) + merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + audio_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_audio_embeds, accumulate=False + ) + + audio_embeds = self.drop(audio_embeds) + + return audio_embeds + + +#################################################### TEXT 
#################################################### + + +class Phi4MultimodalRMSNorm(Phi3RMSNorm): + pass + + +class Phi4MultimodalDecoderLayer(Phi3DecoderLayer): + pass + + +class Phi4MultimodalFeatureEmbedding(nn.Module): + """Image-audio embedding.""" + + def __init__(self, config: Phi4MultimodalConfig) -> None: + super().__init__() + self.config = config + self.image_token_id = config.vision_config.image_token_id + self.audio_token_id = config.audio_config.audio_token_id + self.image_embed = Phi4MultimodalImageEmbedding(config) + self.audio_embed = Phi4MultimodalAudioEmbedding(config) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: Optional[torch.FloatTensor] = None, + audio_input_features: Optional[torch.FloatTensor] = None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + ) -> torch.FloatTensor: + with torch.no_grad(): + image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1) + non_image_position_mask = ~image_position_mask + + image_embeds = None + audio_embeds = None + if image_pixel_values is not None and (input_ids == self.image_token_id).any(): + image_embeds = self.image_embed( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + ) + if audio_input_features is not None and (input_ids == self.audio_token_id).any(): + audio_projection_mode = "vision" if image_pixel_values is not None else "speech" + audio_embeds = self.audio_embed( + input_ids, + inputs_embeds, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + + # merge image and audio + if image_embeds is not None and audio_embeds is not None: + inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask + elif image_embeds is not None: + inputs_embeds = image_embeds + elif audio_embeds is not None: + inputs_embeds = audio_embeds + + return inputs_embeds + + +PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`)`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. 
This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            See our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        image_pixel_values (`torch.FloatTensor`, *optional*):
+            If the input contains images, these correspond to the pixel values after transformations (as returned by
+            the Processor).
+        image_sizes (`torch.LongTensor`, *optional*):
+            If the input contains images, these correspond to the size of each image.
+        image_attention_mask (`torch.LongTensor`, *optional*):
+            Attention mask for the images.
+        audio_input_features (`torch.FloatTensor`, *optional*):
+            If the input contains audio samples, these correspond to the values after transformation (as returned by
+            the Processor).
+        audio_embed_sizes (`torch.Tensor`, *optional*):
+            Size of the audio inputs.
+        audio_attention_mask (`torch.Tensor`, *optional*):
+            Attention mask for the audio inputs.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+class Phi4MultimodalRotaryEmbedding(Phi3RotaryEmbedding):
+    pass
+
+
+class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel):
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, Phi4MultimodalRMSNorm):
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, Phi4MultimodalImageEmbedding):
+            module.global_img_feature_extensor.data.zero_()
+            module.sub_img_feature_extensor.data.zero_()
+
+
+class Phi4MultimodalModel(Phi3Model, nn.Module):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a
+    [`Phi4MultimodalDecoderLayer`].
+
+    Args:
+        config: Phi4MultimodalConfig
+    """
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_dropout = nn.Dropout(config.embd_pdrop)
+
+        self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)
+
+        self.layers = nn.ModuleList(
+            [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_pixel_values: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+        image_attention_mask=None,
+        audio_input_features: Optional[torch.FloatTensor] = None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> BaseModelOutputWithPast:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens_extend( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + audio_input_features=audio_input_features, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Phi4MultimodalModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + audio_input_features: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, 
+ ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + + Example: + ```python + >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM + >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA") + >>> tokenizer = AutoTokenizer.from_pretrained("TBA") + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + image_pixel_values=None, + image_sizes=None, + image_attention_mask=None, + audio_input_features=None, + audio_embed_sizes=None, + audio_attention_mask=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=0, + 
**kwargs,
+    ):
+        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
+        # process
+
+        # The first time the input length reaches the long/short factor switching point, force the cache to be
+        # recomputed. This slows down generation for that single token position, but is better than the failure
+        # that would otherwise occur.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            image_pixel_values=image_pixel_values,
+            image_sizes=image_sizes,
+            image_attention_mask=image_attention_mask,
+            audio_input_features=audio_input_features,
+            audio_embed_sizes=audio_embed_sizes,
+            audio_attention_mask=audio_attention_mask,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+        return model_inputs
+
+
+__all__ = [
+    "Phi4MultimodalAudioPreTrainedModel",
+    "Phi4MultimodalAudioModel",
+    "Phi4MultimodalVisionPreTrainedModel",
+    "Phi4MultimodalVisionModel",
+    "Phi4MultimodalPreTrainedModel",
+    "Phi4MultimodalModel",
+    "Phi4MultimodalForCausalLM",
+    "Phi4MultimodalVisionConfig",
+    "Phi4MultimodalAudioConfig",
+    "Phi4MultimodalConfig",
+]
diff --git a/docs/transformers/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py b/docs/transformers/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0d5a75a59cb6d90122bb508c88e05f30c936a08
--- /dev/null
+++ b/docs/transformers/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py
@@ -0,0 +1,195 @@
+# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Processor class for Phi4Multimodal
+"""
+
+import re
+from typing import List, Optional, Union
+
+from ...audio_utils import AudioInput
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import TextInput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "audio_kwargs": {
+            "device": "cpu",
+        },
+    }
+
+
+class Phi4MultimodalProcessor(ProcessorMixin):
+    r"""
+    Constructs a Phi4Multimodal processor which wraps an image processor, an audio processor, and a GPT tokenizer into a single processor.
+
+    [`Phi4MultimodalProcessor`] offers all the functionalities of [`Phi4MultimodalImageProcessorFast`] and [`GPT2Tokenizer`]. See the
+    [`~Phi4MultimodalProcessor.__call__`] and [`~Phi4MultimodalProcessor.decode`] for more information.
+
+    Args:
+        image_processor (`Phi4MultimodalImageProcessorFast`):
+            The image processor to use for images.
+        audio_processor (`Phi4MultimodalFeatureExtractor`):
+            The audio processor to use for audio inputs.
+        tokenizer (`GPT2TokenizerFast`):
+            The tokenizer to use for text.
+        fake_image_token_pattern (`str`, *optional*, defaults to `r"<\|image_\d+\|>"`):
+            The fake image token pattern.
+        fake_audio_token_pattern (`str`, *optional*, defaults to `r"<\|audio_\d+\|>"`):
+            The fake audio token pattern.
+    """
+
+    attributes = ["image_processor", "audio_processor", "tokenizer"]
+    tokenizer_class = "GPT2TokenizerFast"
+    image_processor_class = "Phi4MultimodalImageProcessorFast"
+    audio_processor_class = "Phi4MultimodalFeatureExtractor"
+    valid_kwargs = ["chat_template"]
+
+    def __init__(
+        self,
+        image_processor,
+        audio_processor,
+        tokenizer,
+        **kwargs,
+    ):
+        self.image_token = tokenizer.image_token
+        self.image_token_id = tokenizer.image_token_id
+        self.audio_token = tokenizer.audio_token
+        self.audio_token_id = tokenizer.audio_token_id
+        super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
+
+    def __call__(
+        self,
+        text: Union[TextInput, List[TextInput]],
+        images: Optional[ImageInput] = None,
+        audio: Optional[AudioInput] = None,
+        **kwargs: Unpack[ProcessingKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        Phi4MultimodalImageProcessorFast's [`~Phi4MultimodalImageProcessorFast.__call__`] if `images` is not `None`.
+        Please refer to the docstring of the above two methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            audio (`List[Union[np.ndarray, torch.Tensor]]`):
+                List of the audios to be prepared.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+            - **input_image_embeds** -- Pixel values to be fed to a model.
+            - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`.
+            - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`.
+            - **input_audio_embeds** -- Audio embeddings to be fed to a model.
+            - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`.
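+
+        Example (an illustrative sketch; the checkpoint name is a placeholder, not a released model id):
+
+        ```python
+        >>> from transformers import AutoProcessor
+        >>> processor = AutoProcessor.from_pretrained("<phi-4-multimodal-checkpoint>")
+        >>> # one `<|image|>` placeholder is expected per image passed in `images`
+        >>> inputs = processor(text="<|image|>Describe this picture.", images=[image], return_tensors="pt")
+        ```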
+ """ + + output_kwargs = self._merge_kwargs(Phi4MultimodalProcessorKwargs, self.tokenizer.init_kwargs, **kwargs) + image_kwargs = output_kwargs["images_kwargs"] + audio_kwargs = output_kwargs["audio_kwargs"] + + image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {} + audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {} + + # We pop here for images as we don't need it later + num_img_tokens = image_inputs.pop("num_img_tokens", []) + audio_embed_sizes = audio_inputs.get("audio_embed_sizes", []) + + # Replace certain special tokens for compatibility + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + image_token = self.tokenizer.image_token + audio_token = self.tokenizer.audio_token + + # Check that the number of special tokens is sound + concatenated_prompt = "".join(text) + if concatenated_prompt.count(image_token) != len(num_img_tokens): + raise ValueError( + "You should add as much image tokens `<|image|>` in your prompt as you pass `images` to the processor. ", + f"Input contains {concatenated_prompt.count(image_token)} tokens != {len(num_img_tokens)} images", + ) + if concatenated_prompt.count(audio_token) != len(audio_embed_sizes): + raise ValueError( + "You should add as much audio tokens `<|audio|>` in your prompt as you pass `audios` to the processor. " + f"Input contains {concatenated_prompt.count(audio_token)} tokens != {len(audio_embed_sizes)} audios" + ) + + # Add appropriate number of image/audio tokens (note that the count of replacement is dynamic) + image_count_iter = iter(num_img_tokens) + audio_count_iter = iter(audio_embed_sizes) + processed_text = [ + re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in text + ] + processed_text = [ + re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text + ] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(processed_text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(processed_text, text_inputs, modalities=["image"]) + + # prepare batch feature + data = { + **text_inputs, + **image_inputs, + **audio_inputs, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names)) + + +__all__ = ["Phi4MultimodalProcessor"] diff --git a/docs/transformers/src/transformers/models/phimoe/__init__.py b/docs/transformers/src/transformers/models/phimoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0849f5c5006e59ab19fee37230056d0add2744c --- /dev/null +++ b/docs/transformers/src/transformers/models/phimoe/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_phimoe import * + from .modeling_phimoe import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/phimoe/configuration_phimoe.py b/docs/transformers/src/transformers/models/phimoe/configuration_phimoe.py new file mode 100644 index 0000000000000000000000000000000000000000..33123ff8ef2dd20aa33d5f8f668f75459170a284 --- /dev/null +++ b/docs/transformers/src/transformers/models/phimoe/configuration_phimoe.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phi-MoE model.""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PhimoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PhimoeModel`]. It is used to instantiate a Phi-moe + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct). 
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phimoe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PhimoeModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 6400): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and + `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_scale` must + be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of + the attention head size and the `original_max_position_embeddings` must be an integer. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `262144`. 
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token; can also be interpreted as the `top-k` routing
+            parameter.
+        num_local_experts (`int`, *optional*, defaults to 16):
+            Number of experts per Sparse MLP layer.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss. See [here]() for more details
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        router_jitter_noise (`float`, *optional*, defaults to 0.01):
+            Amount of noise to add to the router.
+        input_jitter_noise (`float`, *optional*, defaults to 0.0): Input jitter noise
+        attention_bias (`bool`, *optional*, defaults to `False`): Attention bias
+        lm_head_bias (`bool`, *optional*, defaults to `False`): LM head bias
+
+    Example:
+
+    ```python
+    >>> from transformers import PhimoeModel, PhimoeConfig
+    >>> # Initializing a Phimoe (Phi-3.5-MoE) style configuration
+    >>> configuration = PhimoeConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+    >>> # Initializing a model from the configuration
+    >>> model = PhimoeModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "phimoe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32064,
+        hidden_size=4096,
+        intermediate_size=6400,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=1e6,
+        rope_scaling=None,
+        sliding_window=None,
+        attention_dropout=0.0,
+        num_experts_per_tok=2,
+        num_local_experts=16,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        router_jitter_noise=0.01,
+        input_jitter_noise=0.0,
+        attention_bias=False,
+        lm_head_bias=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.attention_bias = attention_bias
+        self.lm_head_bias = lm_head_bias
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+        self.input_jitter_noise = input_jitter_noise
+
+        self.rope_scaling = rope_scaling
+        if isinstance(self.rope_scaling, dict):
+            if "rope_type" not in self.rope_scaling:
+                self.rope_scaling["rope_type"] = self.rope_scaling.get("type", None)
+            if "original_max_position_embeddings" in self.rope_scaling:
+                self.original_max_position_embeddings =
self.rope_scaling["original_max_position_embeddings"] + rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None) + rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None) + if not isinstance(rope_scaling_short_mscale, (int, float)): + raise ValueError( + f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}" + ) + if not isinstance(rope_scaling_long_mscale, (int, float)): + raise ValueError( + f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}" + ) + + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["PhimoeConfig"] diff --git a/docs/transformers/src/transformers/models/phimoe/modeling_phimoe.py b/docs/transformers/src/transformers/models/phimoe/modeling_phimoe.py new file mode 100644 index 0000000000000000000000000000000000000000..683aa431cef3b26b6843b8390a292e4aaf7e09dc --- /dev/null +++ b/docs/transformers/src/transformers/models/phimoe/modeling_phimoe.py @@ -0,0 +1,1627 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phimoe model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_flash_attention_utils import is_flash_attn_available +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_phimoe import PhimoeConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
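+# (Editor's note) A minimal sketch of the `torch.fx.wrap` mechanism, assuming plain symbolic
+# tracing; the toy function below is hypothetical and not part of this model:
+#
+#     import torch.fx
+#
+#     @torch.fx.wrap  # `wrap` returns its argument, so it also works as a decorator
+#     def leaf_fn(x):
+#         return x + 1
+#
+#     def run(x):
+#         return leaf_fn(x) * 2
+#
+#     graph_module = torch.fx.symbolic_trace(run)
+#     # `leaf_fn` now appears as a single call_function node instead of being traced into.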
+_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PhimoeConfig" + + +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class PhimoeRotaryEmbedding(nn.Module): + def __init__( + self, + config: Optional[PhimoeConfig] = None, + ): + super().__init__() + + self.config = 
config + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + self.short_mscale = config.rope_scaling.get("short_mscale") + self.long_mscale = config.rope_scaling.get("long_mscale") + else: + self.rope_type = "default" + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + def forward(self, x, seq_len=None): + mscale = None + if self.config.rope_scaling and seq_len: + mscale = ( + self.long_mscale + if seq_len > self.config.rope_scaling["original_max_position_embeddings"] + else self.short_mscale + ) + inv_freq, attention_scaling = self.rope_init_fn(self.config, x.device, seq_len) + mscale = attention_scaling if mscale is None else mscale + t = torch.arange(seq_len, device=x.device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + return (emb.cos() * mscale).to(x.dtype), (emb.sin() * mscale).to(x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class PhimoeAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: PhimoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.config.attention_bias) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.config.attention_bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = 
attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class PhimoeFlashAttention2(PhimoeAttention):
+    """
+    Phimoe flash attention module. This module inherits from `PhimoeAttention` as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ):
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons;
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to float16 just to be sure everything works as expected.
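+        # The target dtype is resolved in priority order: the active autocast dtype, then the
+        # `_pre_quantization_dtype` recorded on quantized configs, and finally the dtype of the
+        # attention projection weights.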
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class PhimoeSdpaAttention(PhimoeAttention):
+    """
+    Phimoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `PhimoeAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+    to the SDPA API.
+    """
+
+    # Adapted from PhimoeAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "PhimoeModel is using PhimoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
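+        # With `is_causal=True` and no explicit mask, SDPA builds the causal mask internally; for
+        # single-token decoding (q_len == 1) the query may attend to every cached position, so no
+        # mask is needed at all.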
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +PHIMOE_ATTENTION_CLASSES = { + "eager": PhimoeAttention, + "flash_attention_2": PhimoeFlashAttention2, + "sdpa": PhimoeSdpaAttention, +} + + +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe +class PhimoeBlockSparseTop2MLP(nn.Module): + def __init__(self, config: PhimoeConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MultiplierProcessor(torch.autograd.Function): + @staticmethod + def forward( + ctx, + scores: torch.Tensor, + multiplier: torch.Tensor, + selected_experts: torch.Tensor, + masked_gates: torch.Tensor, + mask_for_one: torch.Tensor, + ): + """ + Forward pass for the custom autograd function. + + Args: + ctx: Context object to save information for backward computation. + scores (torch.Tensor): Input scores tensor. + multiplier (torch.Tensor): Multiplier tensor. + selected_experts (torch.Tensor): Tensor of selected experts. + masked_gates (torch.Tensor): Masked gates tensor. + mask_for_one (torch.Tensor): Mask for one tensor. + + Returns: + torch.Tensor: Result of the forward pass. + """ + ctx.save_for_backward(multiplier, selected_experts, masked_gates) + return multiplier * mask_for_one + + @staticmethod + def backward( + ctx, + grad_at_output: torch.Tensor, + ): + """ + Backward pass for the custom autograd function. + + Args: + ctx: Context object with saved tensors from the forward pass. + grad_at_output (torch.Tensor): Gradient at the output. + + Returns: + Tuple[torch.Tensor, None, None, None, None]: Gradients for the inputs. + """ + multiplier, selected_experts, masked_gates = ctx.saved_tensors + + grad_at_output = grad_at_output * multiplier + + grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1) + grad_at_scores_expanded.scatter_add_( + dim=-1, + index=selected_experts, + src=grad_at_output, + ) + + return ( + grad_at_scores_expanded, + None, + None, + None, + None, + ) + + +def sparsemixer(scores, jitter_eps, training, top_k=2): + """ + Sparse mixer function to select top-k experts and compute multipliers. + Based on the paper: https://arxiv.org/pdf/2409.12136 + We first replace the TopK(·) function as random sampling of discrete variables + in model training. Then, following Liu et al. (2023a) and Liu et al. (2023b), we apply Heun's + third order method to approximate the expert routing gradient and construct a modified + back-propagation to give a mathematically sound gradient estimation for expert routing. + + Args: + scores (torch.Tensor): Input scores tensor. 
+ jitter_eps (float): Jitter epsilon for numerical stability. + training (bool): Flag indicating if the model is in training mode. + top_k (int): Number of top experts to select. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors. + """ + if top_k != 2: + raise ValueError("top_k must be equal to 2") + + # first expert + + with torch.no_grad(): + # Compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # Apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + if training: + selected_experts = ( + ( + masked_gates + - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) + .max(dim=-1)[1] + .unsqueeze(-1) + ) # Gumbel sampling, more robust than the multinomial method + else: + selected_experts = max_ind + + # Compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + if training: + # Compute midpoint mask + max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True) + mask_for_one = torch.logical_or( + selected_experts == max_ind, + torch.rand_like(max_scores) > 0.75, # Heun's third-order method + ) + # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5 + mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates) + + multiplier = MultiplierProcessor.apply( + scores, + multiplier_o, + selected_experts, + masked_gates, + mask_for_one, + ) + else: + multiplier = multiplier_o + + # Masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + with torch.no_grad(): + # Compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # Apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf")) + if training: + selected_experts_top2 = ( + ( + masked_gates_top2 + - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() + ) + .max(dim=-1)[1] + .unsqueeze(-1) + ) # Gumbel sampling, more robust than the multinomial method + else: + selected_experts_top2 = max_ind + # Compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2) + + if training: + # Compute midpoint mask + max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True) + mask_for_one_top2 = torch.logical_or( + selected_experts_top2 == max_ind, + torch.rand_like(max_scores).uniform_() > 0.75, # Heun's third-order method + ) + # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5 + mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2) + + multiplier_top2 = MultiplierProcessor.apply( + scores, + multiplier_top2_o, + selected_experts_top2, + masked_gates_top2, + mask_for_one_top2, + ) + else: + multiplier_top2 = multiplier_top2_o + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +class 
PhimoeSparseMoeBlock(nn.Module):
+    """
+    This implementation is strictly equivalent to standard MoE with full capacity (no dropped
+    tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to
+    accommodate imbalanced assignments of tokens to experts, whereas standard MoE either (1) drops
+    tokens at the cost of reduced performance or (2) sets the capacity factor to the number of
+    experts and thus wastes computation and memory on padding.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([PhimoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+        # Jitter parameters
+        self.router_jitter_noise = config.router_jitter_noise
+        self.input_jitter_noise = config.input_jitter_noise
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Route each token to its top-2 experts and combine the weighted expert outputs."""
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        if self.training and self.input_jitter_noise > 0:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
+            )
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights, selected_experts = sparsemixer(
+            router_logits,
+            jitter_eps=self.router_jitter_noise,
+            training=self.training,
+        )
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing, so we'll use
+            # the `top_x` tensor here.
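+            # `index_add_` accumulates in place, so a token selected by both of its top-2 experts
+            # receives the weighted sum of both expert outputs across loop iterations.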
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class PhimoeDecoderLayer(nn.Module): + def __init__(self, config: PhimoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = PHIMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.block_sparse_moe = PhimoeSparseMoeBlock(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +PHIMOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`PhimoeConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phimoe Model outputting raw hidden-states without any specific head on top.", + PHIMOE_START_DOCSTRING, +) +class PhimoePreTrainedModel(PreTrainedModel): + config_class = PhimoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PhimoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +PHIMOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Phimoe Model outputting raw hidden-states without any specific head on top.", + PHIMOE_START_DOCSTRING, +) +class PhimoeModel(PhimoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhimoeDecoderLayer`] + Args: + config: PhimoeConfig + """ + + def __init__(self, config: PhimoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [PhimoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True) + self.rotary_emb = PhimoeRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + position_embeddings = self.rotary_emb(hidden_states, seq_len=cache_position[-1] + 1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Phimoe + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phimoe. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phimoe + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: PhimoeConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`PhimoeConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.get_text_config().sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.get_text_config().sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = PhimoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from 
transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, PhimoeForCausalLM + >>> model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + if ( + use_cache + and self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.original_max_position_embeddings + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." 
+ ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +@add_start_docstrings( + """ + The Phimoe Model transformer with a sequence classification head on top (linear layer). + [`PhimoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. 
If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + PHIMOE_START_DOCSTRING, +) + +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phimoe, LLAMA->PHIMOE, BaseModelOutputWithPast->MoeModelOutputWithPast +class PhimoeForSequenceClassification(PhimoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PhimoeModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PHIMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: MoeModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "PhimoePreTrainedModel", + "PhimoeModel", + "PhimoeForCausalLM", + "PhimoeForSequenceClassification", +] diff --git a/docs/transformers/src/transformers/models/phobert/__init__.py b/docs/transformers/src/transformers/models/phobert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c14d34aff61362159e0b45e5895ceb08850faf5 --- /dev/null +++ b/docs/transformers/src/transformers/models/phobert/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_phobert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/phobert/tokenization_phobert.py b/docs/transformers/src/transformers/models/phobert/tokenization_phobert.py new file mode 100644 index 0000000000000000000000000000000000000000..471b3a89f8536ffa0427760de38ab4ccb93e20c1 --- /dev/null +++ b/docs/transformers/src/transformers/models/phobert/tokenization_phobert.py @@ -0,0 +1,351 @@ +# coding=utf-8 +# Copyright (c) 2020, VinAI Research and the HuggingFace Inc. team. +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for PhoBERT""" + +import os +import re +from shutil import copyfile +from typing import List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.txt", + "merges_file": "bpe.codes", +} + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. 
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+
+    return pairs
+
+
+class PhobertTokenizer(PreTrainedTokenizer):
+    """
+    Construct a PhoBERT tokenizer. Based on Byte-Pair-Encoding.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        **kwargs,
+    ):
+        self.vocab_file = vocab_file
+        self.merges_file = merges_file
+
+        self.encoder = {}
+        self.encoder[str(bos_token)] = 0
+        self.encoder[str(pad_token)] = 1
+        self.encoder[str(eos_token)] = 2
+        self.encoder[str(unk_token)] = 3
+
+        self.add_from_file(vocab_file)
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            merges = merges_handle.read().split("\n")[:-1]
+        merges = [tuple(merge.split()[:-1]) for merge in merges]
+
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {}
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            **kwargs,
+        )
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A PhoBERT sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. PhoBERT does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    @property
+    def vocab_size(self):
+        return len(self.encoder)
+
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = "@@ ".join(word)
+        # Strip the trailing "</w>" end-of-word marker appended above.
+        word = word[:-4]
+        self.cache[token] = word
+        return word
+
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        split_tokens = []
+
+        words = re.findall(r"\S+\n?", text)
+
+        for token in words:
+            split_tokens.extend(list(self.bpe(token).split(" ")))
+        return split_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        out_string = " ".join(tokens).replace("@@ ", "").strip()
+        return out_string
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        out_merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
+            copyfile(self.merges_file, out_merge_file)
+
+        return out_vocab_file, out_merge_file
+
+    # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+    #     filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
+    #     tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
+    #     tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
+    #     return ''.join(tokens_generated_so_far)
+
+    def add_from_file(self, f):
+        """
+        Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
+        """
+        if isinstance(f, str):
+            try:
+                with open(f, "r", encoding="utf-8") as fd:
+                    self.add_from_file(fd)
+            except FileNotFoundError as fnfe:
+                raise fnfe
+            except UnicodeError:
+                raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
+            return
+
+        lines = f.readlines()
+        for lineTmp in lines:
+            line = lineTmp.strip()
+            idx = line.rfind(" ")
+            if idx == -1:
+                raise ValueError("Incorrect dictionary format, expected '<token> <count>'")
+            word = line[:idx]
+            self.encoder[word] = len(self.encoder)
+
+
+__all__ = ["PhobertTokenizer"]
diff --git a/docs/transformers/src/transformers/models/pix2struct/__init__.py b/docs/transformers/src/transformers/models/pix2struct/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa645dff0494dd08019e5a5a3e0069abc9f63c88
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pix2struct/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_pix2struct import *
+    from .image_processing_pix2struct import *
+    from .modeling_pix2struct import *
+    from .processing_pix2struct import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/pix2struct/configuration_pix2struct.py b/docs/transformers/src/transformers/models/pix2struct/configuration_pix2struct.py
new file mode 100644
index 0000000000000000000000000000000000000000..db2f0ff7e3ca3187ee8ad00897c7aa5124011ad8
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pix2struct/configuration_pix2struct.py
@@ -0,0 +1,350 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pix2Struct model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Pix2StructTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate
+    a Pix2Struct text model according to the specified arguments, defining the model architecture.
Instantiating a + configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by + the [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50244): + Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`Pix2StructTextModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Dimensionality of the key, query, value projections in each attention head. + d_ff (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + dense_act_fn (`Union[Callable, str]`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string). + decoder_start_token_id (`int`, *optional*, defaults to 0): + The id of the `decoder_start_token_id` token. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the `padding` token. + eos_token_id (`int`, *optional*, defaults to 1): + The id of the `end-of-sequence` token. 
+ + Example: + + ```python + >>> from transformers import Pix2StructTextConfig, Pix2StructTextModel + + >>> # Initializing a Pix2StructTextConfig with google/pix2struct-base style configuration + >>> configuration = Pix2StructTextConfig() + + >>> # Initializing a Pix2StructTextModel (with random weights) from the google/pix2struct-base style configuration + >>> model = Pix2StructTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pix2struct_text_model" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "hidden_size", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "decoder_attention_heads": "num_heads", + "encoder_attention_heads": "num_heads", + "encoder_layers": "num_layers", + "decoder_layers": "num_layers", + } + + def __init__( + self, + vocab_size=50244, + hidden_size=768, + d_kv=64, + d_ff=2048, + num_layers=12, + num_heads=12, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + dense_act_fn="gelu_new", + decoder_start_token_id=0, + use_cache=False, + pad_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + is_decoder=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.use_cache = use_cache + + self.eos_token_id = eos_token_id + self.decoder_start_token_id = decoder_start_token_id + + # for backwards compatibility + self.dense_act_fn = dense_act_fn + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + tie_word_embeddings=tie_word_embeddings, + is_decoder=is_decoder, + **kwargs, + ) + + +class Pix2StructVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to + instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture. + Instantiating a configuration defaults will yield a similar configuration to that of the Pix2Struct-base + [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + patch_embed_hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the input patch_embedding layer in the Transformer encoder. + d_ff (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + d_kv (`int`, *optional*, defaults to 64): + Dimensionality of the key, query, value projections per attention head. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dense_act_fn (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        dropout_rate (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 1e-10):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+        seq_len (`int`, *optional*, defaults to 4096):
+            Maximum sequence length (here number of patches) supported by the model.
+        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+            The number of buckets to use for each attention layer.
+        relative_attention_max_distance (`int`, *optional*, defaults to 128):
+            The maximum distance (in tokens) to use for each attention layer.
+
+    Example:
+
+    ```python
+    >>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel
+
+    >>> # Initializing a Pix2StructVisionConfig with google/pix2struct-base style configuration
+    >>> configuration = Pix2StructVisionConfig()
+
+    >>> # Initializing a Pix2StructVisionModel (with random weights) from the google/pix2struct-base style configuration
+    >>> model = Pix2StructVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "pix2struct_vision_model"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        patch_embed_hidden_size=768,
+        d_ff=2048,
+        d_kv=64,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        dense_act_fn="gelu_new",
+        layer_norm_eps=1e-6,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        initializer_range=1e-10,
+        initializer_factor=1.0,
+        seq_len=4096,
+        relative_attention_num_buckets=32,
+        relative_attention_max_distance=128,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.patch_embed_hidden_size = patch_embed_hidden_size
+        self.d_ff = d_ff
+        self.dropout_rate = dropout_rate
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.dense_act_fn = dense_act_fn
+        self.seq_len = seq_len
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        self.relative_attention_max_distance = relative_attention_max_distance
+        self.d_kv = d_kv
+
+
+class Pix2StructConfig(PretrainedConfig):
+    r"""
+    [`Pix2StructConfig`] is the configuration class to store the configuration of a
+    [`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified
+    arguments, defining the text model and vision model configs.
Instantiating a configuration with the defaults will + yield a similar configuration to that of the Pix2Struct-base + [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Pix2StructTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Pix2StructVisionConfig`]. + initializer_factor (`float`, *optional*, defaults to 1.0): + Factor to multiply the initialization range with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + is_vqa (`bool`, *optional*, defaults to `False`): + Whether the model has been fine-tuned for VQA or not. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration + + >>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration + >>> configuration = Pix2StructConfig() + + >>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration + >>> model = Pix2StructForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig + + >>> # Initializing a Pix2Struct text and Pix2Struct vision configuration + >>> config_text = Pix2StructTextConfig() + >>> config_vision = Pix2StructVisionConfig() + + >>> config = Pix2StructConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "pix2struct" + + def __init__( + self, + text_config=None, + vision_config=None, + initializer_factor=1.0, + initializer_range=0.02, + is_vqa=False, + tie_word_embeddings=False, + is_encoder_decoder=True, + **kwargs, + ): + super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the Pix2StructTextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. 
Initializing the Pix2StructVisionConfig with default values.") + + text_config["is_encoder_decoder"] = is_encoder_decoder + text_config["tie_word_embeddings"] = tie_word_embeddings + self.text_config = Pix2StructTextConfig(**text_config) + self.vision_config = Pix2StructVisionConfig(**vision_config) + + self.decoder_start_token_id = self.text_config.decoder_start_token_id + self.pad_token_id = self.text_config.pad_token_id + self.eos_token_id = self.text_config.eos_token_id + + self.initializer_factor = initializer_factor + self.initializer_range = initializer_range + + self.text_config.initializer_range = self.initializer_range + self.vision_config.initializer_range = self.initializer_range + + self.is_vqa = is_vqa + + @classmethod + def from_text_vision_configs( + cls, text_config: Pix2StructTextConfig, vision_config: Pix2StructVisionConfig, **kwargs + ): + r""" + Instantiate a [`Pix2StructConfig`] (or a derived class) from pix2struct text model configuration and pix2struct + vision model configuration. + + Returns: + [`Pix2StructConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +__all__ = ["Pix2StructConfig", "Pix2StructTextConfig", "Pix2StructVisionConfig"] diff --git a/docs/transformers/src/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py b/docs/transformers/src/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..457c2236694ad1367fada658a10905400e537da1 --- /dev/null +++ b/docs/transformers/src/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
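+# Usage sketch (the paths below are illustrative, not part of the original script):
+#
+#   python convert_pix2struct_original_pytorch_to_hf.py \
+#       --t5x_checkpoint_path /path/to/t5x_checkpoint \
+#       --pytorch_dump_folder_path ./pix2struct-hf \
+#       --use_large
+#
+# Running the conversion requires the optional `t5x` and `flax` packages imported below.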
+import argparse +import os +import re + +import torch +from flax.traverse_util import flatten_dict +from t5x import checkpoints + +from transformers import ( + AutoTokenizer, + Pix2StructConfig, + Pix2StructForConditionalGeneration, + Pix2StructImageProcessor, + Pix2StructProcessor, + Pix2StructTextConfig, + Pix2StructVisionConfig, +) + + +def get_flax_param(t5x_checkpoint_path): + flax_params = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) + flax_params = flatten_dict(flax_params) + return flax_params + + +def rename_and_convert_flax_params(flax_dict): + converted_dict = {} + + CONVERSION_MAPPING = { + "token_embedder": "embeddings", + "encoder_norm": "layernorm", + "kernel": "weight", + ".out": ".output", + "scale": "weight", + "embedders_0.pos_embedding": "row_embedder.weight", + "embedders_1.pos_embedding": "column_embedder.weight", + } + + DECODER_CONVERSION_MAPPING = { + "query": "attention.query", + "key": "attention.key", + "value": "attention.value", + "output.dense": "output", + "encoder_decoder_attention.o": "encoder_decoder_attention.attention.o", + "pre_self_attention_layer_norm": "self_attention.layer_norm", + "pre_cross_attention_layer_norm": "encoder_decoder_attention.layer_norm", + "mlp.": "mlp.DenseReluDense.", + "pre_mlp_layer_norm": "mlp.layer_norm", + "self_attention.o": "self_attention.attention.o", + "decoder.embeddings.embedding": "decoder.embed_tokens.weight", + "decoder.relpos_bias.rel_embedding": "decoder.layer.0.self_attention.attention.relative_attention_bias.weight", + "decoder.decoder_norm.weight": "decoder.final_layer_norm.weight", + "decoder.logits_dense.weight": "decoder.lm_head.weight", + } + + for key in flax_dict.keys(): + if "target" in key: + # remove the first prefix from the key + new_key = ".".join(key[1:]) + + # rename the key + for old, new in CONVERSION_MAPPING.items(): + new_key = new_key.replace(old, new) + + if "decoder" in new_key: + for old, new in DECODER_CONVERSION_MAPPING.items(): + new_key = new_key.replace(old, new) + + if "layers" in new_key and "decoder" not in new_key: + # use regex to replace the layer number + new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key) + new_key = new_key.replace("encoder", "encoder.encoder") + + elif "layers" in new_key and "decoder" in new_key: + # use regex to replace the layer number + new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key) + + converted_dict[new_key] = flax_dict[key] + + converted_torch_dict = {} + # convert converted_dict into torch format + for key in converted_dict.keys(): + if ("embed_tokens" not in key) and ("embedder" not in key): + converted_torch_dict[key] = torch.from_numpy(converted_dict[key].T) + else: + converted_torch_dict[key] = torch.from_numpy(converted_dict[key]) + + return converted_torch_dict + + +def convert_pix2struct_original_pytorch_checkpoint_to_hf( + t5x_checkpoint_path, pytorch_dump_folder_path, use_large=False, is_vqa=False +): + flax_params = get_flax_param(t5x_checkpoint_path) + + if not use_large: + encoder_config = Pix2StructVisionConfig() + decoder_config = Pix2StructTextConfig() + else: + encoder_config = Pix2StructVisionConfig( + hidden_size=1536, d_ff=3968, num_attention_heads=24, num_hidden_layers=18 + ) + decoder_config = Pix2StructTextConfig(hidden_size=1536, d_ff=3968, num_heads=24, num_layers=18) + config = Pix2StructConfig( + vision_config=encoder_config.to_dict(), text_config=decoder_config.to_dict(), is_vqa=is_vqa + ) + + model = Pix2StructForConditionalGeneration(config) + + torch_params = 
rename_and_convert_flax_params(flax_params)
+    model.load_state_dict(torch_params)
+
+    tok = AutoTokenizer.from_pretrained("ybelkada/test-pix2struct-tokenizer")
+    image_processor = Pix2StructImageProcessor()
+    processor = Pix2StructProcessor(image_processor=image_processor, tokenizer=tok)
+
+    if use_large:
+        processor.image_processor.max_patches = 4096
+
+    # Propagate the VQA flag so the saved processor matches the model config.
+    processor.image_processor.is_vqa = is_vqa
+
+    # mkdir if needed
+    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+
+    model.save_pretrained(pytorch_dump_folder_path)
+    processor.save_pretrained(pytorch_dump_folder_path)
+
+    print("Model saved in {}".format(pytorch_dump_folder_path))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--t5x_checkpoint_path", default=None, type=str, help="Path to the original T5x checkpoint.")
+    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+    parser.add_argument("--use_large", action="store_true", help="Use large model.")
+    parser.add_argument("--is_vqa", action="store_true", help="Whether the checkpoint is a VQA model.")
+    args = parser.parse_args()
+
+    convert_pix2struct_original_pytorch_checkpoint_to_hf(
+        args.t5x_checkpoint_path, args.pytorch_dump_folder_path, args.use_large, args.is_vqa
+    )
diff --git a/docs/transformers/src/transformers/models/pix2struct/image_processing_pix2struct.py b/docs/transformers/src/transformers/models/pix2struct/image_processing_pix2struct.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9db5175b2ebd3a11b48790a8714677c30743756
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pix2struct/image_processing_pix2struct.py
@@ -0,0 +1,464 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Pix2Struct."""
+
+import io
+import math
+from typing import Dict, Optional, Union
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    get_image_size,
+    infer_channel_dimension_format,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
+from ...utils import TensorType, is_torch_available, is_vision_available, logging
+from ...utils.import_utils import requires_backends
+
+
+if is_vision_available():
+    import textwrap
+
+    from PIL import Image, ImageDraw, ImageFont
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+DEFAULT_FONT_PATH = "ybelkada/fonts"
+
+
+# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
+def torch_extract_patches(image_tensor, patch_height, patch_width):
+    """
+    Utility function to extract patches from a given image tensor.
Returns a tensor of shape (1, `patch_height`, + `patch_width`, `num_channels`x `patch_height` x `patch_width`) + + Args: + image_tensor (torch.Tensor): + The image tensor to extract patches from. + patch_height (int): + The height of the patches to extract. + patch_width (int): + The width of the patches to extract. + """ + requires_backends(torch_extract_patches, ["torch"]) + + image_tensor = image_tensor.unsqueeze(0) + patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) + patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1) + patches = patches.permute(0, 4, 2, 3, 1).reshape( + image_tensor.size(2) // patch_height, + image_tensor.size(3) // patch_width, + image_tensor.size(1) * patch_height * patch_width, + ) + return patches.unsqueeze(0) + + +# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L106 +def render_text( + text: str, + text_size: int = 36, + text_color: str = "black", + background_color: str = "white", + left_padding: int = 5, + right_padding: int = 5, + top_padding: int = 5, + bottom_padding: int = 5, + font_bytes: Optional[bytes] = None, + font_path: Optional[str] = None, +) -> Image.Image: + """ + Render text. This script is entirely adapted from the original script that can be found here: + https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py + + Args: + text (`str`, *optional*, defaults to ): + Text to render. + text_size (`int`, *optional*, defaults to 36): + Size of the text. + text_color (`str`, *optional*, defaults to `"black"`): + Color of the text. + background_color (`str`, *optional*, defaults to `"white"`): + Color of the background. + left_padding (`int`, *optional*, defaults to 5): + Padding on the left. + right_padding (`int`, *optional*, defaults to 5): + Padding on the right. + top_padding (`int`, *optional*, defaults to 5): + Padding on the top. + bottom_padding (`int`, *optional*, defaults to 5): + Padding on the bottom. + font_bytes (`bytes`, *optional*): + Bytes of the font to use. If `None`, the default font will be used. + font_path (`str`, *optional*): + Path to the font to use. If `None`, the default font will be used. + """ + requires_backends(render_text, "vision") + # Add new lines so that each line is no more than 80 characters. + + wrapper = textwrap.TextWrapper(width=80) + lines = wrapper.wrap(text=text) + wrapped_text = "\n".join(lines) + + if font_bytes is not None and font_path is None: + font = io.BytesIO(font_bytes) + elif font_path is not None: + font = font_path + else: + font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF") + font = ImageFont.truetype(font, encoding="UTF-8", size=text_size) + + # Use a temporary canvas to determine the width and height in pixels when + # rendering the text. + temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color)) + _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font) + + # Create the actual image with a bit of padding around the text. 
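+    # Illustrative sizing example (hypothetical numbers, not from the original code):
+    # a rendered text block of 200x50 px with the default 5 px padding on every side
+    # produces a 210x60 px canvas.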
+    image_width = text_width + left_padding + right_padding
+    image_height = text_height + top_padding + bottom_padding
+    image = Image.new("RGB", (image_width, image_height), background_color)
+    draw = ImageDraw.Draw(image)
+    draw.text(xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font)
+    return image
+
+
+# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
+def render_header(
+    image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
+):
+    """
+    Renders the input text as a header on the input image.
+
+    Args:
+        image (`np.ndarray`):
+            The image to render the header on.
+        header (`str`):
+            The header text.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If unset, it is inferred from the input image.
+
+    Returns:
+        `np.ndarray`: The image with the header rendered.
+    """
+    requires_backends(render_header, "vision")
+
+    # Convert to PIL image if necessary
+    image = to_pil_image(image, input_data_format=input_data_format)
+
+    header_image = render_text(header, **kwargs)
+    new_width = max(header_image.width, image.width)
+
+    new_height = int(image.height * (new_width / image.width))
+    new_header_height = int(header_image.height * (new_width / header_image.width))
+
+    new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
+    new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
+    new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
+
+    # Convert back to the original framework if necessary
+    new_image = to_numpy_array(new_image)
+
+    if infer_channel_dimension_format(new_image) == ChannelDimension.LAST:
+        new_image = to_channel_dimension_format(new_image, ChannelDimension.LAST)
+
+    return new_image
+
+
+class Pix2StructImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a Pix2Struct image processor.
+
+    Args:
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method. According to Pix2Struct paper and code, the image is normalized with its own mean and standard
+            deviation.
+        patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
+            The patch size to use for the image. According to Pix2Struct paper and code, the patch size is 16x16.
+        max_patches (`int`, *optional*, defaults to 2048):
+            The maximum number of patches to extract from the image as per the [Pix2Struct
+            paper](https://arxiv.org/pdf/2210.03347.pdf).
+        is_vqa (`bool`, *optional*, defaults to `False`):
+            Whether or not the image processor is for the VQA task. If `True` and `header_text` is passed in, text is
+            rendered onto the input images.
+ """ + + model_input_names = ["flattened_patches"] + + def __init__( + self, + do_convert_rgb: bool = True, + do_normalize: bool = True, + patch_size: Dict[str, int] = None, + max_patches: int = 2048, + is_vqa: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} + self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb + self.max_patches = max_patches + self.is_vqa = is_vqa + + def extract_flattened_patches( + self, + image: np.ndarray, + max_patches: int, + patch_size: dict, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Extract flattened patches from an image. + + Args: + image (`np.ndarray`): + Image to extract flattened patches from. + max_patches (`int`): + Maximum number of patches to extract. + patch_size (`dict`): + Dictionary containing the patch height and width. + + Returns: + result (`np.ndarray`): + A sequence of `max_patches` flattened patches. + """ + requires_backends(self.extract_flattened_patches, "torch") + + # convert to torch + image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format) + image = torch.from_numpy(image) + + patch_height, patch_width = patch_size["height"], patch_size["width"] + image_height, image_width = get_image_size(image, ChannelDimension.FIRST) + + # maximize scale s.t. + scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width)) + num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1) + num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1) + resized_height = max(num_feasible_rows * patch_height, 1) + resized_width = max(num_feasible_cols * patch_width, 1) + + image = torch.nn.functional.interpolate( + image.unsqueeze(0), + size=(resized_height, resized_width), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze(0) + + # [1, rows, columns, patch_height * patch_width * image_channels] + patches = torch_extract_patches(image, patch_height, patch_width) + + patches_shape = patches.shape + rows = patches_shape[1] + columns = patches_shape[2] + depth = patches_shape[3] + + # [rows * columns, patch_height * patch_width * image_channels] + patches = patches.reshape([rows * columns, depth]) + + # [rows * columns, 1] + row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1]) + col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1]) + + # Offset by 1 so the ids do not contain zeros, which represent padding. + row_ids += 1 + col_ids += 1 + + # Prepare additional patch features. + # [rows * columns, 1] + row_ids = row_ids.to(torch.float32) + col_ids = col_ids.to(torch.float32) + + # [rows * columns, 2 + patch_height * patch_width * image_channels] + result = torch.cat([row_ids, col_ids, patches], -1) + + # [max_patches, 2 + patch_height * patch_width * image_channels] + result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float() + + result = to_numpy_array(result) + + return result + + def normalize( + self, + image: np.ndarray, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. 
+ + The image std is to mimic the tensorflow implementation of the `per_image_standardization`: + https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization + + Args: + image (`np.ndarray`): + Image to normalize. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if image.dtype == np.uint8: + image = image.astype(np.float32) + + # take mean across the whole `image` + mean = np.mean(image) + std = np.std(image) + adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape))) + + return normalize( + image, + mean=mean, + std=adjusted_stddev, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + header_text: Optional[str] = None, + do_convert_rgb: Optional[bool] = None, + do_normalize: Optional[bool] = None, + max_patches: Optional[int] = None, + patch_size: Optional[Dict[str, int]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> ImageInput: + """ + Preprocess an image or batch of images. The processor first computes the maximum possible number of + aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the + image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the + images are standardized following the tensorflow implementation of `per_image_standardization` + (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). + + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images. + header_text (`Union[List[str], str]`, *optional*): + Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + max_patches (`int`, *optional*, defaults to `self.max_patches`): + Maximum number of patches to extract. + patch_size (`dict`, *optional*, defaults to `self.patch_size`): + Dictionary containing the patch height and width. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_patches = max_patches if max_patches is not None else self.max_patches + is_vqa = self.is_vqa + + if kwargs.get("data_format", None) is not None: + raise ValueError("data_format is not an accepted input as the outputs are ") + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if is_vqa: + if header_text is None: + raise ValueError("A header text must be provided for VQA models.") + font_bytes = kwargs.pop("font_bytes", None) + font_path = kwargs.pop("font_path", None) + + if isinstance(header_text, str): + header_text = [header_text] * len(images) + + images = [ + render_header(image, header_text[i], font_bytes=font_bytes, font_path=font_path) + for i, image in enumerate(images) + ] + + if do_normalize: + images = [self.normalize(image=image, input_data_format=input_data_format) for image in images] + + # convert to torch tensor and permute + images = [ + self.extract_flattened_patches( + image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format + ) + for image in images + ] + + # create attention mask in numpy + attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images] + + encoded_outputs = BatchFeature( + data={"flattened_patches": images, "attention_mask": attention_masks}, tensor_type=return_tensors + ) + + return encoded_outputs + + +__all__ = ["Pix2StructImageProcessor"] diff --git a/docs/transformers/src/transformers/models/pix2struct/modeling_pix2struct.py b/docs/transformers/src/transformers/models/pix2struct/modeling_pix2struct.py new file mode 100644 index 0000000000000000000000000000000000000000..192f4a10b12ff81ec3223563b2fb0f58c202ff92 --- /dev/null +++ b/docs/transformers/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -0,0 +1,1907 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pix2Struct modeling file""" + +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Pix2StructConfig" + + +# Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct +class Pix2StructLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + Pix2StructLayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm") +except ImportError: + # using the normal Pix2StructLayerNorm + pass +except Exception: + logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(Pix2StructLayerNorm) + + +class Pix2StructVisionEmbeddings(nn.Module): + r""" + Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models. + Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). 
Each patch
+    is represented by a vector of `hidden_size` values.
+    """
+
+    def __init__(self, config: Pix2StructConfig) -> None:
+        super().__init__()
+        self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size)
+
+        self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size)
+        self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size)
+
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
+        # the row and column indices are stored in the first and second position of the flattened_patches
+        # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2
+        row_indices = flattened_patches[:, :, 0].long()
+        col_indices = flattened_patches[:, :, 1].long()
+
+        flattened_patches = flattened_patches[:, :, 2:]
+
+        embeddings = self.patch_projection(flattened_patches)
+        row_embeddings = self.row_embedder(row_indices)
+        col_embeddings = self.column_embedder(col_indices)
+
+        # sum all embeddings together
+        embeddings = embeddings + row_embeddings + col_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class Pix2StructVisionAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_attention_heads
+        self.dropout = config.attention_dropout
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+        self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+        self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+        self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_bias=None,
+        layer_head_mask=None,
+        output_attentions=False,
+    ):
+        """
+        Self-attention block
+        """
+        # Input is (batch_size, seq_length, dim)
+        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        def to_projection_shape(states):
+            """projection"""
+            return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+        # get query states
+        # (batch_size, n_heads, seq_length, dim_per_head)
+        query_states = to_projection_shape(self.query(hidden_states))
+
+        # get key/value states
+        key_states = to_projection_shape(self.key(hidden_states))
+        value_states = to_projection_shape(self.value(hidden_states))
+
+        # compute scores
+        # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+        scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+        if position_bias is None:
+            position_bias = torch.zeros(
+                (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype
+            )
+            if self.gradient_checkpointing and self.training:
+                position_bias.requires_grad = True
+
+            # default to an all-ones (keep everything) mask first, so that `attention_mask.dim()`
+            # below is only called on an actual tensor
+            if attention_mask is None and not is_torchdynamo_compiling():
+                attention_mask = torch.ones(
+                    (batch_size,
seq_length), device=position_bias.device, dtype=position_bias.dtype
+                )
+
+            if attention_mask is not None and attention_mask.dim() == 2:
+                position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
+            elif attention_mask is not None:
+                # (batch_size, n_heads, seq_length, key_length)
+                position_bias = position_bias + attention_mask.to(position_bias.device)
+
+            # the mask uses 1 = keep / 0 = drop; invert it so kept positions add a 0 bias and
+            # dropped positions are filled with -inf below
+            position_bias = 1 - position_bias
+
+        position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
+        scores += position_bias_masked
+        scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
+
+        # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)
+
+        # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        # Mask heads if we want to
+        if layer_head_mask is not None:
+            attn_weights = attn_weights * layer_head_mask
+
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        # (batch_size, seq_length, dim)
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+        attn_output = self.output(attn_output)
+
+        outputs = (attn_output,) + (position_bias,)
+
+        if output_attentions:
+            outputs = outputs + (attn_weights,)
+        return outputs
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate
+class Pix2StructVisionMlp(nn.Module):
+    def __init__(self, config: Pix2StructVisionConfig):
+        super().__init__()
+        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
+        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
+        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = ACT2FN[config.dense_act_fn]
+
+    def forward(self, hidden_states):
+        hidden_gelu = self.act(self.wi_0(hidden_states))
+        hidden_linear = self.wi_1(hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        hidden_states = self.dropout(hidden_states)
+
+        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
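+        # The activations are therefore upcast to `self.wo`'s dtype right before the final projection below.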
+        # See https://github.com/huggingface/transformers/issues/20287
+        # we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class Pix2StructVisionLayer(nn.Module):
+    def __init__(self, config: Pix2StructConfig) -> None:
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = Pix2StructVisionAttention(config)
+        self.mlp = Pix2StructVisionMlp(config)
+        self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        residual = hidden_states
+
+        # in Pix2StructVision, layernorm is applied before self-attention
+        hidden_states = self.pre_attention_layer_norm(hidden_states)
+
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=head_mask,
+            output_attentions=output_attentions,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # first residual connection
+        hidden_states = attention_output + residual
+
+        # in Pix2StructVision, layernorm is also applied before the MLP block
+        layer_output = self.pre_mlp_layer_norm(hidden_states)
+        layer_output = self.mlp(layer_output) + hidden_states  # second residual connection
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+class Pix2StructVisionEncoder(nn.Module):
+    def __init__(self, config: Pix2StructConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states,
all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Pix2StructPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Pix2StructConfig + _supports_cache_class = True + _supports_static_cache = False + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, Pix2StructLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, Pix2StructTextDenseGatedActDense): + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff + + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, Pix2StructTextAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + key_value_proj_dim = ( + self.config.text_config.d_kv if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size + ) + n_heads = ( + self.config.text_config.num_heads + if isinstance(self.config, Pix2StructConfig) + else self.config.num_heads + ) + + module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) + module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) + module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) + module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + elif isinstance(module, nn.Embedding): + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + + module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Pix2StructTextModel): + hidden_size = ( + self.config.text_config.hidden_size + if isinstance(self.config, Pix2StructConfig) + else self.config.hidden_size + ) + + 
module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, Pix2StructLayerNorm):
+            if module.weight is not None:
+                module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct
+    def _shift_right(self, input_ids):
+        decoder_start_token_id = self.config.decoder_start_token_id
+        pad_token_id = self.config.pad_token_id
+
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. "
+                "See Pix2Struct docs for more information."
+            )
+
+        # shift inputs to the right
+        if is_torch_fx_proxy(input_ids):
+            # Item assignment is not supported natively for proxies.
+            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
+            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
+        else:
+            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+            shifted_input_ids[..., 0] = decoder_start_token_id
+
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
+        # replace possible -100 values in labels by `pad_token_id`
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+        return shifted_input_ids
+
+
+PIX2STRUCT_VISION_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`Pix2StructConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PIX2STRUCT_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`):
+            Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See
+            [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original
+            paper](https://arxiv.org/abs/2210.03347) (figure 5) for more details. The patch layout is sketched below.
+
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
+
+            - 1 for pixel values that are **not masked**,
+            - 0 for pixel values that are **masked** (padding patches).
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
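A minimal sketch of that patch layout, with hypothetical sizes and invented values: the first two features of each patch vector hold the patch's row and column index (consumed by `Pix2StructVisionEmbeddings` above; the exact index convention comes from the patch extractor, which is not shown here), the remaining features are the flattened pixels, and padding patches are all zeros, so the attention mask can be recovered from the patches themselves:

```python
import torch

batch_size, seq_length = 1, 4          # hypothetical: 3 real patches + 1 padding patch
patch_dim = 3 * 16 * 16                # num_channels * patch_height * patch_width

flattened_patches = torch.zeros(batch_size, seq_length, 2 + patch_dim)
for i, (row, col) in enumerate([(1, 1), (1, 2), (2, 1)]):   # positions of the real patches
    flattened_patches[0, i, 0] = row                        # read by the row embedder
    flattened_patches[0, i, 1] = col                        # read by the column embedder
    flattened_patches[0, i, 2:] = torch.rand(patch_dim)     # flattened pixel values

# padding patches are all-zero, so a mask falls out of the data itself
attention_mask = (flattened_patches.sum(dim=-1) != 0).float()  # tensor([[1., 1., 1., 0.]])
```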
+ + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Pix2StructVision Model transformer outputting raw hidden-states without any specific head on top.", + PIX2STRUCT_VISION_START_DOCSTRING, +) +class Pix2StructVisionModel(Pix2StructPreTrainedModel): + config_class = Pix2StructVisionConfig + main_input_name = "flattened_patches" + supports_gradient_checkpointing = True + _no_split_modules = ["Pix2StructVisionLayer"] + + def __init__(self, config: Pix2StructConfig): + super().__init__(config) + self.config = config + + self.embeddings = Pix2StructVisionEmbeddings(config) + self.encoder = Pix2StructVisionEncoder(config) + + self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_projection + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(PIX2STRUCT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + flattened_patches: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Example: + + ```python + >>> import requests + >>> from PIL import Image + >>> from transformers import AutoProcessor, Pix2StructVisionModel + + >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") + >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... 
outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 2048, 768] + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if flattened_patches is None: + raise ValueError("You have to specify flattened_patches") + + if attention_mask is None: + # check where `flattened_patches` is not 0 + attention_mask = (flattened_patches.sum(dim=-1) != 0).float() + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(flattened_patches) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pix2StructText,d_model->hidden_size +class Pix2StructTextDenseGatedActDense(nn.Module): + def __init__(self, config: Pix2StructTextConfig): + super().__init__() + self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+        # See https://github.com/huggingface/transformers/issues/20287
+        # we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class Pix2StructTextLayerFF(nn.Module):
+    def __init__(self, config: Pix2StructTextConfig):
+        super().__init__()
+        self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)
+
+        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    # Copied from transformers.models.t5.modeling_t5.T5LayerFF.forward
+    def forward(self, hidden_states):
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states
+
+
+class Pix2StructTextAttention(nn.Module):
+    def __init__(
+        self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None
+    ):
+        super().__init__()
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.hidden_size = config.hidden_size
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        self.pruned_heads = set()
+        self.gradient_checkpointing = False
+
+    @staticmethod
+    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=False, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
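+
+        When `past_key_value` is provided, newly computed key/value states are written to the cache. For
+        cross-attention, the key/value states derived from `key_value_states` are computed once on the first
+        call and re-used on subsequent decoding steps.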
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.query(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.key(current_states) + value_states = self.value(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + 
attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.output(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size +class Pix2StructTextLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.attention = Pix2StructTextAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size +class Pix2StructTextLayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class Pix2StructTextBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + + self.self_attention = Pix2StructTextLayerSelfAttention( + config, + 
has_relative_attention_bias=has_relative_attention_bias,
+            layer_idx=layer_idx,
+        )
+
+        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
+            config,
+            layer_idx=layer_idx,
+        )
+
+        self.mlp = Pix2StructTextLayerFF(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_bias=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        encoder_decoder_position_bias=None,
+        layer_head_mask=None,
+        cross_attn_layer_head_mask=None,
+        past_key_value=None,
+        use_cache=False,
+        output_attentions=False,
+        return_dict=True,
+        cache_position=None,
+    ):
+        self_attention_outputs = self.self_attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        hidden_states, past_key_value = self_attention_outputs[:2]
+        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
+
+        # clamp inf values to enable fp16 training
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        do_cross_attention = encoder_hidden_states is not None
+        if do_cross_attention:
+            cross_attention_outputs = self.encoder_decoder_attention(
+                hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                position_bias=encoder_decoder_position_bias,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=past_key_value,
+                query_length=cache_position[-1] + 1,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+            hidden_states, past_key_value = cross_attention_outputs[:2]
+
+            # clamp inf values to enable fp16 training
+            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+            # Keep cross-attention outputs and relative position weights
+            attention_outputs = attention_outputs + cross_attention_outputs[2:]
+
+        # Apply Feed Forward layer
+        hidden_states = self.mlp(hidden_states)
+
+        # clamp inf values to enable fp16 training
+        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if use_cache:
+            outputs = outputs + (past_key_value,) + attention_outputs
+        else:
+            outputs = outputs + attention_outputs
+
+        return outputs
+
+
+PIX2STRUCT_START_DOCSTRING = r"""
+
+    The Pix2Struct model was proposed in [Pix2Struct: Screenshot Parsing as Pretraining for Visual Language
+    Understanding](https://arxiv.org/abs/2210.03347) by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu,
+    Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova. It's an encoder-decoder
+    transformer pre-trained in an image-to-text setting.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
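The decoder block above repeatedly clamps half-precision activations just below the `float16` maximum so that the following residual additions cannot overflow to `inf`. A standalone sketch of that trick, with toy values not taken from the model:

```python
import torch

h = torch.tensor([10_000.0, 65_504.0], dtype=torch.float16)  # 65504 is the fp16 maximum
h = h * 2                                                     # the second entry overflows to inf

if h.dtype == torch.float16 and torch.isinf(h).any():
    clamp_value = torch.finfo(h.dtype).max - 1000             # leave headroom for later additions
    h = torch.clamp(h, min=-clamp_value, max=clamp_value)

print(h)  # the overflowed entry is pulled back to a finite value just below the fp16 maximum
```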
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (Union[`Pix2StructConfig`, `Pix2StructTextConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PIX2STRUCT_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position + embeddings so you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [Pix2StructText + Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention layers. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+            of `inputs_embeds`.
+
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+            cache in the correct position and to infer the complete sequence length.
+"""
+
+PIX2STRUCT_INPUTS_DOCSTRING = r"""
+    Args:
+        flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
+            Flattened pixel patches. The `hidden_size` is obtained by the following formula: `hidden_size` =
+            `num_channels` * `patch_size` * `patch_size`
+
+            The process of flattening the pixel patches is done by `Pix2StructProcessor`.
+
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+ + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention layers. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. 
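To make the `past_key_values` contract above concrete, here is a hypothetical two-step decoding loop using the standalone text decoder documented below; the token ids are invented for illustration and this is a sketch, not a prescribed API usage:

```python
import torch
from transformers import Pix2StructTextModel

model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")
model.eval()

# step 1: run the full prefix once and keep the returned cache
input_ids = torch.tensor([[0, 42, 7]])  # hypothetical decoder token ids
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)
past = out.past_key_values

# step 2: feed only the newly generated token together with the cache,
# instead of re-running the whole sequence
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
with torch.no_grad():
    out = model(input_ids=next_token, past_key_values=past, use_cache=True)
```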
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss for the decoder. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The standalone text decoder of Pix2Struct", + PIX2STRUCT_START_DOCSTRING, +) +class Pix2StructTextModel(Pix2StructPreTrainedModel): + config_class = Pix2StructTextConfig + _no_split_modules = ["Pix2StructTextBlock"] + _tied_weights_keys = ["lm_head.weight"] + supports_gradient_checkpointing = True + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layer = nn.ModuleList( + [ + Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + self.gradient_checkpointing = False + + # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._reorder_cache + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + 
@add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoProcessor, Pix2StructTextModel + + >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") + >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base") + + >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> loss = outputs.loss + ``` + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None: + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length + ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + cache_position, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, next_decoder_cache = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") + + loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) + + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [ + loss, + logits, + next_cache, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == 
"flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. 
+ cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@add_start_docstrings( + "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", + PIX2STRUCT_START_DOCSTRING, +) +class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): + config_class = Pix2StructConfig + main_input_name = "flattened_patches" + _tied_weights_keys = ["decoder.lm_head.weight"] + + def __init__(self, config: Pix2StructConfig): + super().__init__(config) + + self.encoder = Pix2StructVisionModel(config.vision_config) + self.decoder = Pix2StructTextModel(config.text_config) + + self.is_vqa = config.is_vqa + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.decoder.set_input_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.decoder.set_output_embeddings(new_embeddings) + + def get_decoder(self): + return self.decoder + + def get_encoder(self): + return self.encoder + + @add_start_docstrings_to_model_forward(PIX2STRUCT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + flattened_patches: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + Inference: + + ```python + >>> from 
PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") + >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> # autoregressive generation + >>> generated_ids = model.generate(**inputs, max_new_tokens=50) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print(generated_text) + A stop sign is on a street corner. + + >>> # conditional generation + >>> text = "A picture of" + >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=50) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print(generated_text) + A picture of a stop sign with a red stop sign + ``` + + Training: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base") + >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base") + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "A stop sign is on the street corner." + + >>> inputs = processor(images=image, return_tensors="pt") + >>> labels = processor(text=text, return_tensors="pt").input_ids + + >>> # forward pass + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> print(f"{loss.item():.5f}") + 5.94282 + ```""" + use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + flattened_patches=flattened_patches, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + decoder_attention_mask = ( + decoder_attention_mask + if decoder_attention_mask is not None + else decoder_input_ids.ne(self.config.pad_token_id).float() + ) + # Always attend to the first token + decoder_attention_mask[:, 0] = 1 + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + 
cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            labels=labels,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqLMOutput(
+            loss=decoder_outputs.loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+__all__ = [
+    "Pix2StructPreTrainedModel",
+    "Pix2StructForConditionalGeneration",
+    "Pix2StructVisionModel",
+    "Pix2StructTextModel",
+]
diff --git a/docs/transformers/src/transformers/models/pix2struct/processing_pix2struct.py b/docs/transformers/src/transformers/models/pix2struct/processing_pix2struct.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b1fcc44087d6452c86dcd7e910391d617fb3a4
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pix2struct/processing_pix2struct.py
@@ -0,0 +1,157 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Pix2Struct.
+"""
+
+from typing import List, Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+class Pix2StructImagesKwargs(ImagesKwargs, total=False):
+    max_patches: Optional[int]
+    header_text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
+
+
+class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Pix2StructImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {
+            "max_patches": 2048,
+        },
+    }
+
+
+logger = logging.get_logger(__name__)
+
+
+class Pix2StructProcessor(ProcessorMixin):
+    r"""
+    Constructs a PIX2STRUCT processor which wraps a T5 tokenizer and PIX2STRUCT image processor into a single
+    processor.
+
+    [`Pix2StructProcessor`] offers all the functionalities of [`Pix2StructImageProcessor`] and [`T5TokenizerFast`]. See
+    the docstring of [`~Pix2StructProcessor.__call__`] and [`~Pix2StructProcessor.decode`] for more information.
+
+    Args:
+        image_processor (`Pix2StructImageProcessor`):
+            An instance of [`Pix2StructImageProcessor`]. The image processor is a required input.
+        tokenizer (Union[`T5TokenizerFast`, `T5Tokenizer`]):
+            An instance of [`T5TokenizerFast`] or [`T5Tokenizer`]. The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "Pix2StructImageProcessor"
+    tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
+
+    def __init__(self, image_processor, tokenizer):
+        tokenizer.return_token_type_ids = False
+        super().__init__(image_processor, tokenizer)
+
+    def __call__(
+        self,
+        images=None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Pix2StructProcessorKwargs],
+    ) -> Union[BatchEncoding, BatchFeature]:
+        """
+        This method uses the [`Pix2StructImageProcessor.preprocess`] method to prepare image(s) for the model, and
+        [`T5TokenizerFast.__call__`] to prepare text for the model.
+
+        Please refer to the docstring of the above two methods for more information.
+        """
+        if images is None and text is None:
+            raise ValueError("You have to specify either images or text.")
+
+        output_kwargs = self._merge_kwargs(
+            Pix2StructProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
+        # Text-only input (non-VQA): tokenize and return immediately
+        if images is None and not self.image_processor.is_vqa:
+            output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                add_special_tokens if add_special_tokens is not None else True
+            )
+            self.current_processor = self.tokenizer
+            text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
+            return text_encoding
+
+        if not self.image_processor.is_vqa:
+            # add pixel_values
+            encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
+        else:
+            # add pixel_values and bbox
+            output_kwargs["images_kwargs"].setdefault("header_text", text)
+            encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+        if text is not None and not self.image_processor.is_vqa:
+            output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                add_special_tokens if add_special_tokens is not None else False
+            )
+            text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
+
+            if "attention_mask" in text_encoding:
+                text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask")
+            if "input_ids" in text_encoding:
+                text_encoding["decoder_input_ids"] = text_encoding.pop("input_ids")
+        else:
+            text_encoding = None
+
+        if text_encoding is not None:
+            encoding_image_processor.update(text_encoding)
+
+        return encoding_image_processor
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`].
+        Please refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.decode`]. Please
+        refer to the docstring of this method for more information.
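+
+        Example (an illustrative sketch, not from the original source; `model`, `processor`
+        and `inputs` are assumed to be defined as in the surrounding examples):
+
+        ```python
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
+        >>> processor.decode(generated_ids[0], skip_special_tokens=True)
+        ```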
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Pix2StructProcessor"] diff --git a/docs/transformers/src/transformers/models/pixtral/__init__.py b/docs/transformers/src/transformers/models/pixtral/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8afbbfb1f47f71d478a9ed475f36719d405b4124 --- /dev/null +++ b/docs/transformers/src/transformers/models/pixtral/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pixtral import * + from .image_processing_pixtral import * + from .image_processing_pixtral_fast import * + from .modeling_pixtral import * + from .processing_pixtral import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/pixtral/configuration_pixtral.py b/docs/transformers/src/transformers/models/pixtral/configuration_pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..d4710e00e421662541e8f2c0417e2719de847ee9 --- /dev/null +++ b/docs/transformers/src/transformers/models/pixtral/configuration_pixtral.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pixtral model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PixtralVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate an + Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B. + + e.g. 
[pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 4096):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of input channels in the input images.
+        image_size (`int`, *optional*, defaults to 1024):
+            Max dimension of the input images.
+        patch_size (`int`, *optional*, defaults to 16):
+            Size of the image patches.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            Activation function used in the hidden layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout probability for the attention layers.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import PixtralVisionModel, PixtralVisionConfig
+
+    >>> # Initializing a Pixtral-12B style configuration
+    >>> configuration = PixtralVisionConfig()
+
+    >>> # Initializing a model (with randomly initialized weights) from the configuration
+    >>> model = PixtralVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "pixtral"
+
+    def __init__(
+        self,
+        hidden_size=1024,
+        intermediate_size=4096,
+        num_hidden_layers=24,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=1024,
+        patch_size=16,
+        hidden_act="gelu",
+        attention_dropout=0.0,
+        rope_theta=10000.0,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.rope_theta = rope_theta
+        self.head_dim = hidden_size // num_attention_heads
+        self.initializer_range = initializer_range
+
+
+__all__ = ["PixtralVisionConfig"]
diff --git a/docs/transformers/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/docs/transformers/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..373aa6cb6e45fcb60e668fc0490a7fed16d7bf6f
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import regex as re +import torch +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from safetensors.torch import load_file as safe_load_file + +from transformers import ( + LlavaConfig, + LlavaForConditionalGeneration, + MistralConfig, + PixtralImageProcessor, + PixtralProcessor, + PixtralVisionConfig, +) + + +""" +# Here is how to get the original tokens! +model_name = "mistralai/Pixtral-12B-2409" +tok = MistralTokenizer.from_model(model_name) + +from mistral_common.protocol.instruct.request import ChatCompletionRequest, UserMessage, ImageChunk, TextChunk + +EXPECTED_TOKENS = tok.encode_chat_completion( + ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text="Describe the images"), + ] + [ImageChunk(image=img) for img in IMG_URLS] + ) + ], + model="pixtral", + ) +) +assert tokenizer.decode(inputs["input_ids"][0]) == EXPECTED_TOKENS +""" + +OLD_KEY_TO_NEW_KEY_MAPPING = { + # Layer Normalization Weights + r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", + r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", + # Self Attention Projections + r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight", + # MLP Projections + r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", + # Additional mappings + r"vision_encoder": r"vision_tower", + r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1", + r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2", + r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight", + r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight", + r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight", + r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", + r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", + r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", + r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", + 
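+    # Illustrative note (not part of the original mapping): `convert_dictionary` below joins
+    # all state-dict keys with newlines and applies these patterns with `re.sub`, so the
+    # `(\d+)` capture survives through the `\1` backreference. For example, the key
+    # "layers.7.feed_forward.w1.weight" becomes
+    # "language_model.model.layers.7.mlp.gate_proj.weight".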
r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", + r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", + r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight", + r"output.weight": r"language_model.lm_head.weight", + r"norm.weight": r"language_model.model.norm.weight", +} + + +def convert_mistral_tokenizer(model_file): + from transformers import LlamaTokenizer + + mistral_tokenizer = MistralTokenizer.from_file(model_file) + vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab() + control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens + all_special = [vocab[id] for id in control_token_ids] + hf_tokenizer = LlamaTokenizer(model_file) + # Do I need to exclude tokens that are already special? + hf_tokenizer.add_special_tokens({"additional_special_tokens": all_special}) + hf_tokenizer.model_input_names = ["input_ids", "attention_mask"] + return hf_tokenizer + + +def permute_for_rope(value, n_heads, config): + dim1 = value.shape[0] + dim2 = config.hidden_size + return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def convert_dictionary(original_state_dict, vision_config, text_config): + new_dict = {} + + all_keys = "\n" + "\n".join(original_state_dict.keys()) + old_keys = all_keys + for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items(): + all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys) + + OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n"))) + + for key, value in original_state_dict.items(): + new_key = OLD_TO_NEW[key] + if "vision_encoder" in key: + _config = vision_config + num_attention_heads = _config.num_attention_heads + else: + _config = text_config + if "q_proj" in new_key: + num_attention_heads = _config.num_attention_heads + if "k_proj" in new_key: + num_attention_heads = _config.num_key_value_heads + + if "q_proj" in new_key or "k_proj" in new_key: + value = permute_for_rope(value, num_attention_heads, _config) + + new_dict[new_key] = value + return new_dict + + +MISTRAL_CONFIG_MAPPING = { + "dim": "hidden_size", + "hidden_dim": "intermediate_size", + "n_kv_heads": "num_key_value_heads", + "n_heads": "num_attention_heads", + "n_layers": "num_hidden_layers", +} + + +def convert_mistral_model(input_dir, output_dir): + vision_config = {} + if os.path.isfile(f"{input_dir}/params.json"): + with open(f"{input_dir}/params.json") as f: + param_json = json.load(f) + vision_config = param_json.pop("vision_encoder") + for k, v in MISTRAL_CONFIG_MAPPING.items(): + value = param_json.pop(k) + param_json[v] = value + if "hidden_act" not in vision_config: + vision_config["hidden_act"] = "silu" + text_config = MistralConfig( + **param_json, + hidden_act="silu", + sliding_window=None, + tie_word_embeddings=False, + rms_norm_eps=1e-5, + ) + else: + text_config = MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + head_dim=128, + hidden_act="silu", + hidden_size=5120, + initializer_range=0.02, + intermediate_size=14336, + max_position_embeddings=1024000, + model_type="mistral", + num_attention_heads=32, + num_hidden_layers=40, + num_key_value_heads=8, + rms_norm_eps=1e-05, + rope_theta=1000000000.0, + sliding_window=None, + tie_word_embeddings=False, + vocab_size=131072, + ) + adapter_bias = vision_config.pop("adapter_bias", True) + vision_config = PixtralVisionConfig(**vision_config) + config = LlavaConfig( + vision_config, + text_config, + vision_feature_layer=-1, + 
image_token_id=10, + vision_feature_select_strategy="full", + image_seq_length=1, + multimodal_projector_bias=adapter_bias, + ) + config.architectures = ["LlavaForConditionalGeneration"] + config.save_pretrained(output_dir) + full_original_state_dict = {} + safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")]) + if len(safetensors_files) == 1: + full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors") + else: + for file in safetensors_files: + loaded_dict = safe_load_file(f"{input_dir}/{file}") + full_original_state_dict.update(loaded_dict) + + new_dict = convert_dictionary(full_original_state_dict, vision_config, text_config) + with torch.device("meta"): + model = LlavaForConditionalGeneration(config) + model.load_state_dict(new_dict, strict=True, assign=True) + model.save_pretrained(output_dir) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of LLaMA weights, which contains tokenizer.model and model folders", + required=True, + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + required=True, + ) + parser.add_argument( + "--tokenizer_file", help="Location of the specific tokenizer model file to use.", required=True + ) + parser.add_argument( + "--chat_template_file", + help="Optional file containing a raw chat template. Will be set as the processor's chat template.", + required=False, + ) + + args = parser.parse_args() + convert_mistral_model(args.input_dir, args.output_dir) + tokenizer = convert_mistral_tokenizer(args.tokenizer_file) + image_processor = PixtralImageProcessor() + processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]") + if args.chat_template_file: + processor.chat_template = open(args.chat_template_file).read() + processor.save_pretrained(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/src/transformers/models/pixtral/image_processing_pixtral.py b/docs/transformers/src/transformers/models/pixtral/image_processing_pixtral.py new file mode 100644 index 0000000000000000000000000000000000000000..074a8d1076a5112796be5eece74c824f6c95811c --- /dev/null +++ b/docs/transformers/src/transformers/models/pixtral/image_processing_pixtral.py @@ -0,0 +1,472 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for Pixtral.""" + +import math +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + pad, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging +from ...utils.import_utils import requires_backends + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. +def convert_to_rgb(image: ImageInput) -> ImageInput: + """ + Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + Args: + image (Image): + The image to convert. + """ + requires_backends(convert_to_rgb, ["vision"]) + + if not isinstance(image, PIL.Image.Image): + return image + + if image.mode == "RGB": + return image + + # First we convert to RGBA to set background to white. + image = image.convert("RGBA") + + # Create a new image with a white background. + new_image = PIL.Image.new("RGBA", image.size, "WHITE") + new_image.paste(image, (0, 0), image) + new_image = new_image.convert("RGB") + return new_image + + +def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int: + """ + Calculate the number of image tokens given the image size and patch size. + + Args: + image_size (`Tuple[int, int]`): + The size of the image as `(height, width)`. + patch_size (`Tuple[int, int]`): + The patch size as `(height, width)`. + + Returns: + `int`: The number of image tokens. + """ + height, width = image_size + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + num_width_tokens = (width - 1) // patch_width + 1 + num_height_tokens = (height - 1) // patch_height + 1 + return num_height_tokens, num_width_tokens + + +def get_resize_output_image_size( + input_image: ImageInput, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], + patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`ImageInput`): + The image to resize. + size (`int` or `Tuple[int, int]`): + Max image size an input image can be. Must be a dictionary with the key "longest_edge". + patch_size (`int` or `Tuple[int, int]`): + The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)` + will be used + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `tuple`: The target (height, width) dimension of the output image after resizing. 
+    """
+    max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
+    patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
+    height, width = get_image_size(input_image, input_data_format)
+
+    ratio = max(height / max_height, width / max_width)
+
+    if ratio > 1:
+        # The original implementation uses `round`, which applies banker's rounding and can lead to surprising
+        # results. Here we use floor to ensure the image is always smaller than the given "longest_edge"
+        height = int(math.floor(height / ratio))
+        width = int(math.floor(width / ratio))
+
+    num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
+    return num_height_tokens * patch_height, num_width_tokens * patch_width
+
+
+class PixtralImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a Pixtral image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+            `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
+            Size of the maximum dimension of either the height or width dimension of the image. Used to control how
+            images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width / ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
+        patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
+            Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
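+
+    Example (an illustrative usage sketch, not from the original source):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import PixtralImageProcessor
+
+    >>> image_processor = PixtralImageProcessor()
+    >>> image = np.zeros((512, 640, 3), dtype=np.uint8)  # dummy 512x640 RGB image
+    >>> inputs = image_processor(image, return_tensors="pt")
+    >>> list(inputs.keys())
+    ['pixel_values', 'image_sizes']
+    ```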
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        patch_size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"longest_edge": 1024}
+        patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
+        patch_size = get_size_dict(patch_size, default_to_square=True)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.patch_size = patch_size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
+        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
+        self.do_convert_rgb = do_convert_rgb
+        self._valid_processor_keys = [
+            "images",
+            "do_resize",
+            "size",
+            "patch_size",
+            "resample",
+            "do_rescale",
+            "rescale_factor",
+            "do_normalize",
+            "image_mean",
+            "image_std",
+            "do_convert_rgb",
+            "return_tensors",
+            "data_format",
+            "input_data_format",
+        ]
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        patch_size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image. The longest edge of the image is resized so it does not exceed size["longest_edge"], with
+        the other edge resized to keep the input aspect ratio; both are then rounded up to whole multiples of the
+        patch size.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Dict containing the longest possible edge of the image.
+            patch_size (`Dict[str, int]`):
+                Patch size used to calculate the size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        if "longest_edge" in size:
+            size = (size["longest_edge"], size["longest_edge"])
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
+
+        if "height" in patch_size and "width" in patch_size:
+            patch_size = (patch_size["height"], patch_size["width"])
+        else:
+            raise ValueError("patch_size must contain 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(
+            image,
+            size=size,
+            patch_size=patch_size,
+            input_data_format=input_data_format,
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    def _pad_for_batching(
+        self,
+        pixel_values: List[np.ndarray],
+        image_sizes: List[List[int]],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        """
+        Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
+        Args:
+            pixel_values (`List[np.ndarray]`):
+                A list of pixel values of each image, each of shape (`height`, `width`, `channels`)
+            image_sizes (`List[List[int]]`):
+                A list of sizes for each image in `pixel_values` in (height, width) format.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    If unset, will use same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    If unset, will use the inferred format of the input image.
+        Returns:
+            List[`np.ndarray`]: The padded images.
+        """
+
+        max_shape = (
+            max([size[0] for size in image_sizes]),
+            max([size[1] for size in image_sizes]),
+        )
+        pixel_values = [
+            pad(
+                image,
+                padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])),
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
+            for image, size in zip(pixel_values, image_sizes)
+        ]
+        return pixel_values
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Dict[str, int] = None,
+        patch_size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Patch size in the model. Used to calculate the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
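+
+        Example (an illustrative sketch of batching two differently sized images; not from the original source):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import PixtralImageProcessor
+
+        >>> image_processor = PixtralImageProcessor()
+        >>> images = [
+        ...     np.zeros((256, 320, 3), dtype=np.uint8),
+        ...     np.zeros((320, 256, 3), dtype=np.uint8),
+        ... ]
+        >>> inputs = image_processor(images, return_tensors="pt")
+        >>> list(inputs.pixel_values.shape)  # padded to the largest height and width in the batch
+        [2, 3, 320, 320]
+        >>> inputs.image_sizes.tolist()  # per-image (height, width) before padding
+        [[256, 320], [320, 256]]
+        ```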
+ """ + patch_size = patch_size if patch_size is not None else self.patch_size + patch_size = get_size_dict(patch_size, default_to_square=True) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_list_of_images(images) + + if not valid_images(images[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + batch_images = [] + batch_image_sizes = [] + for image in images: + if do_resize: + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + batch_images.append(image) + batch_image_sizes.append(get_image_size(image, data_format)) + + pixel_values = self._pad_for_batching( + pixel_values=batch_images, + image_sizes=batch_image_sizes, + input_data_format=data_format, + data_format=data_format, + ) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors + ) + + +__all__ = ["PixtralImageProcessor"] diff --git a/docs/transformers/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/docs/transformers/src/transformers/models/pixtral/image_processing_pixtral_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb4673038536180b062fa2cfb252c32a959c785 --- /dev/null +++ b/docs/transformers/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Pixtral."""
+
+from typing import Dict, List, Optional, Union
+
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import (
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_utils import (
+    ImageInput,
+    PILImageResampling,
+    SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    add_start_docstrings,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+    is_vision_available,
+    logging,
+)
+from .image_processing_pixtral import (
+    get_resize_output_image_size,
+)
+
+
+logger = logging.get_logger(__name__)
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_vision_available():
+        pass
+
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
+
+
+class PixtralFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    patch_size: Optional[Dict[str, int]]
+
+
+@add_start_docstrings(
+    r"Constructs a fast Pixtral image processor.",
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    """
+        patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
+            Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
+    """,
+)
+class PixtralImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = [0.48145466, 0.4578275, 0.40821073]
+    image_std = [0.26862954, 0.26130258, 0.27577711]
+    patch_size = {"height": 16, "width": 16}
+    size = {"longest_edge": 1024}
+    default_to_square = True
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    do_convert_rgb = True
+    valid_kwargs = PixtralFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[PixtralFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+
+    @add_start_docstrings(
+        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+        """
+            patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
+                Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
+        """,
+    )
+    def preprocess(self, images: ImageInput, **kwargs: Unpack[PixtralFastImageProcessorKwargs]) -> BatchFeature:
+        return super().preprocess(images, **kwargs)
+
+    def resize(
+        self,
+        image: torch.Tensor,
+        size: SizeDict,
+        patch_size: SizeDict,
+        interpolation: "F.InterpolationMode" = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Resize an image. The image is resized so that its longest edge is at most `size["longest_edge"]`, keeping the
+        input aspect ratio, with both output dimensions rounded to multiples of the patch size.
+
+        Args:
+            image (`torch.Tensor`):
+                Image to resize.
+            size (`SizeDict`):
+                Dict containing the longest possible edge of the image.
+            patch_size (`SizeDict`):
+                Patch size used to calculate the size of the output image.
+            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+                Resampling filter to use when resizing the image.
+        """
+        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+        if size.longest_edge:
+            size = (size.longest_edge, size.longest_edge)
+        elif size.height and size.width:
+            size = (size.height, size.width)
+        else:
+            raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
+
+        if patch_size.height and patch_size.width:
+            patch_size = (patch_size.height, patch_size.width)
+        else:
+            raise ValueError("patch_size must contain 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(image, size=size, patch_size=patch_size)
+        return F.resize(image, size=output_size, interpolation=interpolation, **kwargs)
+
+    # Adapted from transformers.models.pixtral.image_processing_pixtral.PixtralImageProcessor._pad_for_batching
+    def _pad_for_batching(
+        self,
+        pixel_values: List[torch.Tensor],
+        image_sizes: List[List[int]],
+    ):
+        """
+        Pads images on the height and width dimensions with zeros so that all images in the batch share the same shape.
+
+        Args:
+            pixel_values (`List[torch.Tensor]`):
+                A list of image tensors, each of shape (`channels`, `height`, `width`)
+            image_sizes (`List[List[int]]`):
+                A list of sizes for each image in `pixel_values` in (height, width) format.
+        Returns:
+            `torch.Tensor`: The padded images, stacked into a single batch tensor.
+        """
+
+        max_shape = (max([size[0] for size in image_sizes]), max([size[1] for size in image_sizes]))
+        pixel_values = [
+            torch.nn.functional.pad(image, pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0]))
+            for image, size in zip(pixel_values, image_sizes)
+        ]
+        return torch.stack(pixel_values)
+
+    def _preprocess(
+        self,
+        images: List["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        patch_size: Dict[str, int],
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: Dict[str, int],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, List[float]]],
+        image_std: Optional[Union[float, List[float]]],
+        return_tensors: Optional[Union[str, TensorType]],
+    ) -> BatchFeature:
+        patch_size = get_size_dict(patch_size, default_to_square=True)
+        patch_size = SizeDict(**patch_size)
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_images, size=size, patch_size=patch_size, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        batch_image_sizes = [grouped_images_index[i][0] for i in range(len(grouped_images_index))]
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_center_crop:
+                stacked_images = self.center_crop(stacked_images, crop_size)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        padded_images = self._pad_for_batching(
+            pixel_values=processed_images,
+            image_sizes=batch_image_sizes,
+        )
+
+        return BatchFeature(
+            data={"pixel_values": padded_images, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
+        )
+
+
+__all__ = ["PixtralImageProcessorFast"]
diff --git a/docs/transformers/src/transformers/models/pixtral/modeling_pixtral.py b/docs/transformers/src/transformers/models/pixtral/modeling_pixtral.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0ec918e997adaf3252c9b6fdc70eff1dff7ef6
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pixtral/modeling_pixtral.py
@@ -0,0 +1,505 @@
+# coding=utf-8
+# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Pixtral model."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ... import PreTrainedModel
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_rope_utils import dynamic_rope_update
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+from .configuration_pixtral import PixtralVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def position_ids_in_meshgrid(patch_embeds_list, max_width):
+    positions = []
+    for patch in patch_embeds_list:
+        height, width = patch.shape[-2:]
+        mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
+        h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
+        ids = h_grid * max_width + v_grid
+        positions.append(ids[:, 0])
+    return torch.cat(positions)
+
+
+class PixtralRotaryEmbedding(nn.Module):
+    """
+    The key idea behind the Pixtral rotary embedding is that there is a frequency for each pixel position.
+    For a grid of height x width patches, the frequency used for RoPE is obtained by indexing a precomputed
+    frequency table with the patch's height and width indices.
+
+    The output has dimension (batch, height * width, dim), where dim is the embedding dimension.
+
+    This simply means that for each image hidden state, a corresponding positional embedding is added,
+    based on its index in the grid.
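+
+    For example (illustrative): with a 2 x 3 grid of patches and `max_width = 64`,
+    `position_ids_in_meshgrid` assigns each patch the id `row * max_width + column` in
+    row-major order, i.e. `[0, 1, 2, 64, 65, 66]`; these ids index the precomputed table.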
+ """ + + def __init__(self, config, device=None): + super().__init__() + self.rope_type = "default" + self.dim = config.head_dim + self.base = config.rope_theta + max_patches_per_side = config.image_size // config.patch_size + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + + h = torch.arange(max_patches_per_side, device=freqs.device) + w = torch.arange(max_patches_per_side, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes + # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # TODO maybe make it torch compatible later on. We can also just slice + self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False) + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + freqs = self.inv_freq[position_ids] + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + emb = freqs + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
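+
+    Example (illustrative, showing the broadcasting behaviour of `unsqueeze_dim`):
+
+    ```python
+    >>> import torch
+    >>> q = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, head_dim]
+    >>> cos = torch.randn(2, 16, 64)  # [batch, seq_len, head_dim]
+    >>> cos.unsqueeze(1).shape  # unsqueeze_dim=1 makes cos broadcastable over the heads dimension
+    torch.Size([2, 1, 16, 64])
+    ```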
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PixtralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, patches, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral +class PixtralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral +class PixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, 
hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class PixtralAttentionLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.feed_forward = PixtralMLP(config) + self.attention = PixtralAttention(config) + self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class PixtralTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(PixtralAttentionLayer(config)) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embeddings which serve as input to the Transformer. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. 
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings=position_embeddings,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+PIXTRAL_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`PixtralVisionConfig`]):
+            Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+class PixtralPreTrainedModel(PreTrainedModel):
+    config_class = PixtralVisionConfig
+    base_model_prefix = "model"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["PixtralAttentionLayer"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, PixtralRMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+PIXTRAL_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
+            for details.
+        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
+            The sizes of the images in the batch, being (height, width) for each image.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+def generate_block_attention_mask(patch_embeds_list, tensor):
+    dtype = tensor.dtype
+    device = tensor.device
+    seq_len = tensor.shape[1]
+    d_min = torch.finfo(dtype).min
+    causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
+
+    block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
+    block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
+    for start, end in zip(block_start_idx, block_end_idx):
+        causal_mask[start:end, start:end] = 0
+
+    causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
+    return causal_mask
+
+
+@add_start_docstrings(
+    "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
+    PIXTRAL_START_DOCSTRING,
+)
+class PixtralVisionModel(PixtralPreTrainedModel):
+    base_model_prefix = "vision_encoder"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+        self.patch_size = config.patch_size
+        self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
+        self.transformer = PixtralTransformer(config)
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.patch_conv
+
+    @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutput]:
+        """
+        Returns:
+            [`BaseModelOutput`] (or a plain tuple if `return_dict=False`) whose `last_hidden_state` holds the
+            features of all patches of all images, flattened into a single sequence of shape
+            `(1, num_patches_total, hidden_size)`.
+        """
+        # pass images through initial convolution independently
+        patch_embeds = self.patch_conv(pixel_values)
+        patch_embeds_list = [
+            embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)]
+            for embed, size in zip(patch_embeds, image_sizes)
+        ]
+
+        # flatten to a single sequence
+        patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
+        patch_embeds = self.ln_pre(patch_embeds)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
+        )
+        position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
+
+        attention_mask = generate_block_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
+        )
+
+        out = self.transformer(
+            patch_embeds,
+            attention_mask=attention_mask,
+            position_embeddings=position_embeddings,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=return_dict,
+        )
+
+        return out
+
+
+__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/pixtral/processing_pixtral.py b/docs/transformers/src/transformers/models/pixtral/processing_pixtral.py
new file mode 100644
index 0000000000000000000000000000000000000000..4531c56b5aaf5d7ac28a2049e00219adf6971bcb
--- /dev/null
+++ 
b/docs/transformers/src/transformers/models/pixtral/processing_pixtral.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Pixtral. +""" + +from typing import List, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image, load_image +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PixtralProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +class PixtralProcessor(ProcessorMixin): + r""" + Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor. + + [`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information. + + Args: + image_processor ([`PixtralImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*, defaults to 16): + Patch size from the vision tower. + spatial_merge_size (`int`, *optional*, defaults to 1): + The downsampling factor for the spatial merge operation. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + image_token (`str`, *optional*, defaults to `"[IMG]"`): + Special token used to denote image location. + image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`): + Special token used to denote the end of a line of pixels in an image. + image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`): + Special token used to denote the end of an image input. 
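+
+    Example (an illustrative sketch, not a tested doctest; `"path/to/pixtral-checkpoint"` and `image` are
+    placeholders for a real checkpoint name and a PIL image):
+
+    ```python
+    >>> from transformers import AutoProcessor
+
+    >>> processor = AutoProcessor.from_pretrained("path/to/pixtral-checkpoint")
+    >>> inputs = processor(images=image, text="[IMG]\nDescribe the image.", return_tensors="pt")
+    >>> # Each [IMG] token in the text is expanded to one token per image patch, with [IMG_BREAK]
+    >>> # after every row of patches and [IMG_END] closing the image.
+    ```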
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = [
+        "chat_template",
+        "patch_size",
+        "spatial_merge_size",
+        "image_token",
+        "image_break_token",
+        "image_end_token",
+    ]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        patch_size: int = 16,
+        spatial_merge_size: int = 1,
+        chat_template=None,
+        image_token="[IMG]",  # set the default and let users change if they have peculiar special tokens in rare cases
+        image_break_token="[IMG_BREAK]",
+        image_end_token="[IMG_END]",
+        **kwargs,
+    ):
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.image_token = image_token
+        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        self.image_break_token = image_break_token
+        self.image_end_token = image_end_token
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[PixtralProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
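+
+        For example (illustrative), an image resized to 48 x 64 pixels with an effective patch size of 16
+        yields a 3 x 4 grid of patches, so a single `[IMG]` token in the prompt is expanded to three rows
+        of `[IMG][IMG][IMG][IMG][IMG_BREAK]`, with the final `[IMG_BREAK]` replaced by `[IMG_END]`.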
+        """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            PixtralProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        patch_size = self.patch_size * self.spatial_merge_size
+
+        if images is not None:
+            if is_image_or_image_url(images):
+                images = [images]
+            elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]):
+                pass
+            elif (
+                isinstance(images, (list, tuple))
+                and isinstance(images[0], (list, tuple))
+                and is_image_or_image_url(images[0][0])
+            ):
+                images = [image for sublist in images for image in sublist]
+            else:
+                raise ValueError(
+                    "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
+                )
+            images = [load_image(im) if isinstance(im, str) else im for im in images]
+            image_inputs = self.image_processor(images, patch_size=patch_size, **output_kwargs["images_kwargs"])
+        else:
+            image_inputs = {}
+
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) and not isinstance(text[0], str):
+            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+        # try to expand inputs in processing if we have the necessary parts
+        prompt_strings = text
+        if image_inputs.get("pixel_values") is not None:
+            # Replace the image token with the expanded image token sequence
+            image_sizes = iter(image_inputs["image_sizes"])
+            prompt_strings = []
+            replace_strings = []
+
+            for sample in text:
+                while self.image_token in sample:
+                    height, width = next(image_sizes)
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
+                    replace_tokens = [
+                        [self.image_token] * num_width_tokens + [self.image_break_token]
+                    ] * num_height_tokens
+                    # Flatten list
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                    replace_tokens[-1] = self.image_end_token
+                    replace_str = "".join(replace_tokens)
+                    replace_strings.append(replace_str)
+                    sample = sample.replace(self.image_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    replace_str = replace_strings.pop(0)
+                    sample = sample.replace("<placeholder>", replace_str, 1)
+                prompt_strings.append(sample)
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+        self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @property
+    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = ["PixtralProcessor"]
diff --git a/docs/transformers/src/transformers/models/plbart/__init__.py b/docs/transformers/src/transformers/models/plbart/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e5618ca7e1727c0ce28458d952ce0bfbbfee59
--- /dev/null
+++ b/docs/transformers/src/transformers/models/plbart/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_plbart import *
+    from .modeling_plbart import *
+    from .tokenization_plbart import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/plbart/configuration_plbart.py b/docs/transformers/src/transformers/models/plbart/configuration_plbart.py
new file mode 100644
index 0000000000000000000000000000000000000000..30871c4b7259a889124f859842b51e6e067bcdfd
--- /dev/null
+++ b/docs/transformers/src/transformers/models/plbart/configuration_plbart.py
@@ -0,0 +1,193 @@
+# coding=utf-8
+# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PLBART model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PLBartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PLBartModel`]. It is used to instantiate a
+    PLBART model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the PLBART
+    [uclanlp/plbart-base](https://huggingface.co/uclanlp/plbart-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50005):
+            Vocabulary size of the PLBART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`PLBartModel`].
+        d_model (`int`, *optional*, defaults to 768):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 6):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+ + Example: + + ```python + >>> from transformers import PLBartConfig, PLBartModel + + >>> # Initializing a PLBART uclanlp/plbart-base style configuration + >>> configuration = PLBartConfig() + + >>> # Initializing a model (with random weights) from the uclanlp/plbart-base style configuration + >>> model = PLBartModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "plbart" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=50005, + max_position_embeddings=1024, + encoder_layers=6, + encoder_ffn_dim=3072, + encoder_attention_heads=12, + decoder_layers=6, + decoder_ffn_dim=3072, + decoder_attention_heads=12, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=768, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + forced_eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + +class PLBartOnnxConfig(OnnxConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.use_past: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + else: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + + +__all__ = ["PLBartConfig"] diff --git a/docs/transformers/src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py b/docs/transformers/src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2bb9553e0b97f8894f77900835b31bf6d38b33 --- /dev/null +++ b/docs/transformers/src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py @@ -0,0 +1,94 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from torch import nn + +from transformers import PLBartConfig, PLBartForConditionalGeneration, PLBartForSequenceClassification + + +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "_float_tensor", + "decoder.output_projection.weight", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_plbart_checkpoint_from_disk( + checkpoint_path, hf_config_path="uclanlp/plbart-base", finetuned=False, classification=False +): + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"] + remove_ignore_keys_(state_dict) + vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0] + + plbart_config = PLBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size) + + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] + if not classification: + model = PLBartForConditionalGeneration(plbart_config) + model.model.load_state_dict(state_dict) + if finetuned: + model.lm_head = make_linear_from_emb(model.model.shared) + + else: + classification_head = {} + for key, value in state_dict.copy().items(): + if key.startswith("classification_heads.sentence_classification_head"): + classification_head[key.replace("classification_heads.sentence_classification_head.", "")] = value + state_dict.pop(key) + model = PLBartForSequenceClassification(plbart_config) + model.model.load_state_dict(state_dict) + model.classification_head.load_state_dict(classification_head) + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("fairseq_path", type=str, help="model.pt on local filesystem.") + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--hf_config", + default="uclanlp/plbart-base", + type=str, + help="Which huggingface architecture to use: plbart-base", + ) + parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint") + parser.add_argument( + "--classification", action="store_true", help="whether the model is a classification checkpoint" + ) + args = parser.parse_args() + model = convert_fairseq_plbart_checkpoint_from_disk( + args.fairseq_path, + hf_config_path=args.hf_config, + finetuned=args.finetuned, + classification=args.classification, + ) + model.save_pretrained(args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/plbart/modeling_plbart.py b/docs/transformers/src/transformers/models/plbart/modeling_plbart.py new file mode 100644 index 0000000000000000000000000000000000000000..e625c20251e87d6817b2879966bbec00bf027632 
--- /dev/null
+++ b/docs/transformers/src/transformers/models/plbart/modeling_plbart.py
@@ -0,0 +1,1731 @@
+# coding=utf-8
+# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PLBART model."""
+
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_plbart import PLBartConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "uclanlp/plbart-base"
+_CONFIG_FOR_DOC = "PLBartConfig"
+
+
+# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+    """
+    Shift input ids one token to the right, and wrap the last non-pad token (the `<LID>` token). Note that MBart does
+    not have a single `decoder_start_token_id` in contrast to other Bart-like models.
+    """
+    prev_output_tokens = input_ids.clone()
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+
+    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
+    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
+    prev_output_tokens[:, 0] = decoder_start_tokens
+
+    return prev_output_tokens
+
+
+# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
+class PLBartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately.
+        # Other models don't have this hack.
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """'input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        ).expand(bsz, -1)
+
+        return super().forward(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart
+class PLBartScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides `nn.Embedding`'s forward by multiplying with the embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
+class PLBartAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[PLBartConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
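For orientation, here is a minimal, self-contained sketch of the shape round-trip that `_shape`, the `proj_shape` views, and the two `torch.bmm` calls above perform. The sizes are toy values, and masking, dropout, and caching are omitted; only the bookkeeping mirrors `PLBartAttention`.

```python
import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8

# per-head projections, as produced by _shape(): (bsz, num_heads, len, head_dim)
q = torch.randn(bsz, num_heads, tgt_len, head_dim) * head_dim**-0.5  # queries are pre-scaled
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# collapse batch and head dims so a single bmm covers every head at once
proj_shape = (bsz * num_heads, -1, head_dim)
q, k, v = q.reshape(proj_shape), k.reshape(proj_shape), v.reshape(proj_shape)

attn_weights = torch.bmm(q, k.transpose(1, 2)).softmax(dim=-1)  # (bsz*heads, tgt, src)
attn_output = torch.bmm(attn_weights, v)                        # (bsz*heads, tgt, head_dim)

# unfold the heads and concatenate them back along the feature dimension
attn_output = attn_output.view(bsz, num_heads, tgt_len, head_dim).transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, num_heads * head_dim)
assert attn_output.shape == (bsz, tgt_len, num_heads * head_dim)
```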
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART +class PLBartEncoderLayer(nn.Module): + def __init__(self, config: PLBartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# TODO: Implement attention with SDPA for PLBart. 
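Looking back at `PLBartLearnedPositionalEmbedding` near the top of this file: positions are counted from `past_key_values_length`, so incremental decoding keeps consistent indices, and every position id is shifted by the hard-coded offset of 2 that reserves the first two rows of the table. A minimal sketch of that arithmetic; the toy sizes are assumptions of this example only.

```python
import torch
import torch.nn as nn

max_positions, embed_dim, offset = 16, 4, 2  # offset=2 mirrors the class above
table = nn.Embedding(max_positions + offset, embed_dim)  # over-allocate by the offset

def positional_embeddings(input_ids: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor:
    bsz, seq_len = input_ids.shape[:2]
    # positions continue from the cached length, they never restart at zero
    positions = torch.arange(past_key_values_length, past_key_values_length + seq_len).expand(bsz, -1)
    return table(positions + offset)

ids = torch.zeros(2, 5, dtype=torch.long)
assert positional_embeddings(ids).shape == (2, 5, embed_dim)  # full prompt pass
assert positional_embeddings(ids[:, -1:], past_key_values_length=4).shape == (2, 1, embed_dim)  # one decode step
```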
+PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention} + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART +class PLBartDecoderLayer(nn.Module): + def __init__(self, config: PLBartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
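The comments inside `PLBartDecoderLayer.forward` below describe `past_key_value` as a flat 4-tuple: self-attention key/value states in the first two slots, cross-attention key/value states in the last two. A toy sketch of that bookkeeping; all tensor contents are dummies, only the slicing and re-assembly follow the layer.

```python
import torch

# dummy cached projections, each (bsz, num_heads, cached_len, head_dim)
k_self = v_self = torch.zeros(1, 4, 3, 8)
k_cross = v_cross = torch.zeros(1, 4, 7, 8)
past_key_value = (k_self, v_self, k_cross, v_cross)

self_attn_past = past_key_value[:2]    # grows by one position per decoding step
cross_attn_past = past_key_value[-2:]  # computed once from the encoder, then reused

# after self-attention, the present tuple is rebuilt in the same slot order
step_k = step_v = torch.zeros(1, 4, 1, 8)
present_key_value = (
    torch.cat([self_attn_past[0], step_k], dim=2),
    torch.cat([self_attn_past[1], step_v], dim=2),
) + cross_attn_past
assert present_key_value[0].shape[2] == 4  # self-attention cache grew by one
assert present_key_value[2].shape[2] == 7  # cross-attention cache is unchanged
```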
+ """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart +class PLBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class PLBartPreTrainedModel(PreTrainedModel): + config_class = PLBartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, 
std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+PLBART_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`PLBartConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PLBART_GENERATION_EXAMPLE = r"""
+    Mask-filling example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
+
+    >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
+    >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
+
+    >>> # en_XX is the language symbol id for English
+    >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
+    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
+
+    >>> logits = model(input_ids).logits
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+    >>> probs = logits[0, masked_index].softmax(dim=0)
+    >>> values, predictions = probs.topk(5)
+
+    >>> tokenizer.decode(predictions).split()
+    ['first', 'same', 'highest', 'result', 'number']
+    ```
+"""
+
+PLBART_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
+            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
+            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the
+            right for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+            `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
+            when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
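The `shift_tokens_right` helper this paragraph refers to is defined earlier in the file and is not shown here. The sketch below is a hypothetical re-implementation for illustration only, assuming the mBART-style convention the docstring describes: the trailing language-id token is rotated to the front as the decoder start token.

```python
import torch

def shift_tokens_right_sketch(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # hypothetical re-implementation for illustration; the real helper lives earlier in this file
    shifted = input_ids.clone()
    shifted.masked_fill_(shifted == -100, pad_token_id)  # labels may carry -100 on ignored positions
    # locate the last non-pad token of each row (the language-id token in PLBart inputs)
    index_of_last = (shifted.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = shifted.gather(1, index_of_last).squeeze(-1)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_tokens
    return shifted

# tokens: [x, y, z, eos-ish, lang_id, pad]; the language id (50003) rotates to the front
ids = torch.tensor([[11, 12, 13, 2, 50003, 1]])
print(shift_tokens_right_sketch(ids, pad_token_id=1))  # tensor([[50003, 11, 12, 13, 2, 50003]])
```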
+ + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart +class PLBartEncoder(PLBartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PLBartEncoderLayer`]. + + Args: + config: PLBartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = PLBartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + if self._use_flash_attention_2: + attention_mask = attention_mask if 0 in attention_mask else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart +class PLBartDecoder(PLBartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`] + + Args: + config: PLBartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = PLBartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, 
BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
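The `_prepare_4d_attention_mask*` helpers used here come from the library's masking utilities; the sketch below shows only the core arithmetic assumed to be involved in the non-causal case: a 2D padding mask is broadcast to `[bsz, 1, tgt_len, src_len]` and turned into an additive mask whose padded columns hold the most negative representable value, so softmax sends them to zero.

```python
import torch

def expand_mask_sketch(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] with 1 = attend, 0 = pad  ->  [bsz, 1, tgt_len, src_len] additive mask
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # padded positions receive the most negative representable value for the dtype
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 1, 0]])
out = expand_mask_sketch(mask, torch.float32, tgt_len=2)
assert out.shape == (1, 1, 2, 4)
assert out[0, 0, 0, 3] == torch.finfo(torch.float32).min  # the pad column is masked everywhere
```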
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return
BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare PLBART Model outputting raw hidden-states without any specific head on top.", + PLBART_START_DOCSTRING, +) +class PLBartModel(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = PLBartEncoder(config, self.shared) + self.decoder = PLBartDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # different to other models, PLBart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a 
BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.", + PLBART_START_DOCSTRING, +) +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + self.model = PLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(PLBART_GENERATION_EXAMPLE) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: 
Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + """ + PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. 
for code + classification. + """, + PLBART_START_DOCSTRING, +) +class PLBartForSequenceClassification(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PLBartModel(config) + self.classification_head = PLBartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
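The forward pass below pools one vector per example out of the decoder by keeping the hidden state at the last `eos` position of each row. A compact sketch of that selection; the tensors and `eos_token_id = 2` are assumptions of the example.

```python
import torch

bsz, seq_len, hidden = 2, 5, 8
hidden_states = torch.randn(bsz, seq_len, hidden)
input_ids = torch.tensor([[5, 6, 2, 7, 2],
                          [8, 2, 9, 10, 2]])  # every row must contain the same number of eos tokens

eos_mask = input_ids.eq(2)  # assumed eos_token_id = 2
# boolean indexing flattens the matches row by row; the view regroups them per example
sentence_representation = hidden_states[eos_mask, :].view(bsz, -1, hidden)[:, -1, :]
assert torch.equal(sentence_representation[0], hidden_states[0, 4])  # last eos of row 0
```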
+
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
+
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+            raise ValueError("All examples must have the same number of <eos> tokens.")
+        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+            :, -1, :
+        ]
+        logits = self.classification_head(sentence_representation)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart
+class PLBartDecoderWrapper(PLBartPreTrainedModel):
+    """
+    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+    used in combination with the [`EncoderDecoderModel`] framework.
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PLBartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base +class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = PLBartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PLBartForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
+        >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = [ + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/plbart/tokenization_plbart.py b/docs/transformers/src/transformers/models/plbart/tokenization_plbart.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b73b6f47414361bee9b85152ec8c5e6258f28e --- /dev/null +++ b/docs/transformers/src/transformers/models/plbart/tokenization_plbart.py @@ -0,0 +1,432 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
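+
+# As implemented below, PLBart appends `[eos, lang_code]` as suffix tokens in both source and
+# target mode (see `set_src_lang_special_tokens` and `set_tgt_lang_special_tokens`). A minimal
+# sketch of what that looks like in practice (the checkpoint name is the canonical one, but any
+# PLBart checkpoint works; exact token IDs are checkpoint-specific):
+#
+#     from transformers import PLBartTokenizer
+#
+#     tok = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="__java__")
+#     ids = tok("int add(int a, int b);").input_ids
+#     assert ids[-2] == tok.eos_token_id  # ... <tokens> <eos> <lang code>
+#     assert ids[-1] == tok.lang_code_to_id["__java__"]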
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+FAIRSEQ_LANGUAGE_CODES = {
+    "base": ["__java__", "__python__", "__en_XX__"],
+    "multi": ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"],
+}
+
+FAIRSEQ_LANGUAGE_CODES_MAP = {
+    "java": "__java__",
+    "python": "__python__",
+    "en_XX": "__en_XX__",
+    "javascript": "__javascript__",
+    "php": "__php__",
+    "ruby": "__ruby__",
+    "go": "__go__",
+}
+
+
+@requires(backends=("sentencepiece",))
+class PLBartTokenizer(PreTrainedTokenizer):
+    """
+    Construct a PLBART tokenizer.
+
+    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+    <tokens> <eos>` for target language documents.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        src_lang (`str`, *optional*):
+            A string representing the source language.
+        tgt_lang (`str`, *optional*):
+            A string representing the target language.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The start of sequence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The cls token, which is a special token used as the first token for all tasks.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masking tasks. This
+            is only used in the `"base"` tokenizer type. For `"multi"` tokenizer, masking is never done for the
+            downstream tasks.
+        language_codes (`str`, *optional*, defaults to `"base"`):
+            What language codes to use. Should be one of `"base"` or `"multi"`.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite, samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
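+
+    For instance, subword regularization can be enabled when loading the tokenizer (a minimal sketch; with
+    sampling enabled, tokenization becomes non-deterministic):
+
+    ```python
+    >>> from transformers import PLBartTokenizer
+
+    >>> tokenizer = PLBartTokenizer.from_pretrained(
+    ...     "uclanlp/plbart-base", sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1}
+    ... )
+    ```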
+
+    Examples:
+
+    ```python
+    >>> from transformers import PLBartTokenizer
+
+    >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
+    >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
+    >>> expected_translation_english = "Returns the maximum value of a b c."
+    >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
+    ```"""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        language_codes="base",
+        tokenizer_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        additional_special_tokens=None,
+        clean_up_tokenization_spaces=True,
+        **kwargs,
+    ):
+        # The mask token behaves like a normal word, i.e. it includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        src_lang = self._convert_lang_code_special_format(src_lang)
+        tgt_lang = self._convert_lang_code_special_format(tgt_lang)
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+        self.language_codes = language_codes
+
+        fairseq_language_codes = FAIRSEQ_LANGUAGE_CODES[self.language_codes]
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
+        self.sp_model_size = len(self.sp_model)
+        self.lang_code_to_id = {
+            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(fairseq_language_codes)
+        }
+        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+
+        if self.language_codes == "base":
+            self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+        _additional_special_tokens = list(self.lang_code_to_id.keys())
+
+        if additional_special_tokens is not None:
+            # Only add those special tokens if they are not already there.
+            _additional_special_tokens.extend(
+                [t for t in additional_special_tokens if t not in _additional_special_tokens]
+            )
+
+        if self.language_codes == "base":
+            self._src_lang = src_lang
+            self.cur_lang_code_id = (
+                self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
+            )
+        else:
+            self._src_lang = src_lang if src_lang is not None else "__en_XX__"
+            self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            language_codes=language_codes,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=_additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+    @property
+    def vocab_size(self):
+        if self.language_codes == "base":
+            return (
+                len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1
+            )  # Plus 1 for the mask token
+        else:
+            return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        new_src_lang = self._convert_lang_code_special_format(new_src_lang)
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        prefix_ones = [1] * len(self.prefix_tokens)
+        suffix_ones = [1] * len(self.suffix_tokens)
+        if token_ids_1 is None:
+            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A PLBART sequence has the following format, where `X` represents the sequence:
+
+        - `input_ids` (for encoder): `X [eos, src_lang_code]`
+        - `decoder_input_ids` (for decoder): `X [eos, tgt_lang_code]`
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. PLBart does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    def _build_translation_inputs(
+        self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+    ):
+        """Used by the translation pipeline to prepare inputs for the generate function"""
+        if src_lang is None or tgt_lang is None:
+            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+        self.src_lang = self._convert_lang_code_special_format(src_lang)
+        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+        tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang)
+        inputs["forced_bos_token_id"] = tgt_lang_id
+        return inputs
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text: str) -> List[str]:
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        if token in self.fairseq_tokens_to_ids:
+            return self.fairseq_tokens_to_ids[token]
+        spm_id = self.sp_model.PieceToId(token)
+
+        # Need to return unknown token if the SP model returned 0
+        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        if index in self.fairseq_ids_to_tokens:
+            return self.fairseq_ids_to_tokens[index]
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) into a single string."""
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "en_XX", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "python", + **kwargs, + ) -> BatchEncoding: + self.src_lang = self._convert_lang_code_special_format(src_lang) + self.tgt_lang = self._convert_lang_code_special_format(tgt_lang) + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code].""" + src_lang = self._convert_lang_code_special_format(src_lang) + self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None + self.prefix_tokens = [] + if self.cur_lang_code is not None: + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code].""" + lang = self._convert_lang_code_special_format(lang) + + self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None + self.prefix_tokens = [] + if self.cur_lang_code is not None: + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.suffix_tokens = [self.eos_token_id] + + def _convert_lang_code_special_format(self, lang: str) -> str: + """Convert Language Codes to format tokenizer uses if required""" + lang = FAIRSEQ_LANGUAGE_CODES_MAP[lang] if lang in FAIRSEQ_LANGUAGE_CODES_MAP.keys() else lang + return lang + + +__all__ = ["PLBartTokenizer"] diff --git a/docs/transformers/src/transformers/models/poolformer/__init__.py b/docs/transformers/src/transformers/models/poolformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e80bb47c302c88ffe33450d48a3b6ac6cf2ffc0 --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_poolformer import * + from .feature_extraction_poolformer import * + from .image_processing_poolformer import * + from .image_processing_poolformer_fast import * + from .modeling_poolformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/poolformer/configuration_poolformer.py b/docs/transformers/src/transformers/models/poolformer/configuration_poolformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cdaf1306313830982fd799d9bdc24dd4f1a7412a --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/configuration_poolformer.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2022 Sea AI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PoolFormer model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PoolFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of [`PoolFormerModel`]. It is used to instantiate a + PoolFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the PoolFormer + [sail/poolformer_s12](https://huggingface.co/sail/poolformer_s12) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of channels in the input image. + patch_size (`int`, *optional*, defaults to 16): + The size of the input patch. + stride (`int`, *optional*, defaults to 16): + The stride of the input patch. + pool_size (`int`, *optional*, defaults to 3): + The size of the pooling window. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the number of channels in the output of the MLP to the number of channels in the input. + depths (`list`, *optional*, defaults to `[2, 2, 6, 2]`): + The depth of each encoder block. + hidden_sizes (`list`, *optional*, defaults to `[64, 128, 320, 512]`): + The hidden sizes of each encoder block. + patch_sizes (`list`, *optional*, defaults to `[7, 3, 3, 3]`): + The size of the input patch for each encoder block. + strides (`list`, *optional*, defaults to `[4, 2, 2, 2]`): + The stride of the input patch for each encoder block. 
+ padding (`list`, *optional*, defaults to `[2, 1, 1, 1]`): + The padding of the input patch for each encoder block. + num_encoder_blocks (`int`, *optional*, defaults to 4): + The number of encoder blocks. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The dropout rate for the dropout layers. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function for the hidden layers. + use_layer_scale (`bool`, *optional*, defaults to `True`): + Whether to use layer scale. + layer_scale_init_value (`float`, *optional*, defaults to 1e-05): + The initial value for the layer scale. + initializer_range (`float`, *optional*, defaults to 0.02): + The initializer range for the weights. + + Example: + + ```python + >>> from transformers import PoolFormerConfig, PoolFormerModel + + >>> # Initializing a PoolFormer sail/poolformer_s12 style configuration + >>> configuration = PoolFormerConfig() + + >>> # Initializing a model (with random weights) from the sail/poolformer_s12 style configuration + >>> model = PoolFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "poolformer" + + def __init__( + self, + num_channels=3, + patch_size=16, + stride=16, + pool_size=3, + mlp_ratio=4.0, + depths=[2, 2, 6, 2], + hidden_sizes=[64, 128, 320, 512], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + padding=[2, 1, 1, 1], + num_encoder_blocks=4, + drop_path_rate=0.0, + hidden_act="gelu", + use_layer_scale=True, + layer_scale_init_value=1e-5, + initializer_range=0.02, + **kwargs, + ): + self.num_channels = num_channels + self.patch_size = patch_size + self.stride = stride + self.padding = padding + self.pool_size = pool_size + self.hidden_sizes = hidden_sizes + self.mlp_ratio = mlp_ratio + self.depths = depths + self.patch_sizes = patch_sizes + self.strides = strides + self.num_encoder_blocks = num_encoder_blocks + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_layer_scale = use_layer_scale + self.layer_scale_init_value = layer_scale_init_value + self.initializer_range = initializer_range + super().__init__(**kwargs) + + +class PoolFormerOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 2e-3 + + +__all__ = ["PoolFormerConfig", "PoolFormerOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/poolformer/convert_poolformer_original_to_pytorch.py b/docs/transformers/src/transformers/models/poolformer/convert_poolformer_original_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..ddcfb9cd24193714929c9be4759dcea198457bb3 --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/convert_poolformer_original_to_pytorch.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert PoolFormer checkpoints from the original repository. URL: https://github.com/sail-sg/poolformer""" + +import argparse +import json +from collections import OrderedDict +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import PoolFormerConfig, PoolFormerForImageClassification, PoolFormerImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def replace_key_with_offset(key, offset, original_name, new_name): + """ + Replaces the key by subtracting the offset from the original layer number + """ + to_find = original_name.split(".")[0] + key_list = key.split(".") + orig_block_num = int(key_list[key_list.index(to_find) - 2]) + layer_num = int(key_list[key_list.index(to_find) - 1]) + new_block_num = orig_block_num - offset + + key = key.replace(f"{orig_block_num}.{layer_num}.{original_name}", f"block.{new_block_num}.{layer_num}.{new_name}") + return key + + +def rename_keys(state_dict): + new_state_dict = OrderedDict() + total_embed_found, patch_emb_offset = 0, 0 + for key, value in state_dict.items(): + if key.startswith("network"): + key = key.replace("network", "poolformer.encoder") + if "proj" in key: + # Works for the first embedding as well as the internal embedding layers + if key.endswith("bias") and "patch_embed" not in key: + patch_emb_offset += 1 + to_replace = key[: key.find("proj")] + key = key.replace(to_replace, f"patch_embeddings.{total_embed_found}.") + key = key.replace("proj", "projection") + if key.endswith("bias"): + total_embed_found += 1 + if "patch_embeddings" in key: + key = "poolformer.encoder." + key + if "mlp.fc1" in key: + key = replace_key_with_offset(key, patch_emb_offset, "mlp.fc1", "output.conv1") + if "mlp.fc2" in key: + key = replace_key_with_offset(key, patch_emb_offset, "mlp.fc2", "output.conv2") + if "norm1" in key: + key = replace_key_with_offset(key, patch_emb_offset, "norm1", "before_norm") + if "norm2" in key: + key = replace_key_with_offset(key, patch_emb_offset, "norm2", "after_norm") + if "layer_scale_1" in key: + key = replace_key_with_offset(key, patch_emb_offset, "layer_scale_1", "layer_scale_1") + if "layer_scale_2" in key: + key = replace_key_with_offset(key, patch_emb_offset, "layer_scale_2", "layer_scale_2") + if "head" in key: + key = key.replace("head", "classifier") + new_state_dict[key] = value + return new_state_dict + + +# We will verify our results on a COCO image +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + return image + + +@torch.no_grad() +def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our PoolFormer structure. 
+ """ + + # load default PoolFormer configuration + config = PoolFormerConfig() + + # set attributes based on model_name + repo_id = "huggingface/label-files" + size = model_name[-3:] + config.num_labels = 1000 + filename = "imagenet-1k-id2label.json" + expected_shape = (1, 1000) + + # set config attributes + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + if size == "s12": + config.depths = [2, 2, 6, 2] + config.hidden_sizes = [64, 128, 320, 512] + config.mlp_ratio = 4.0 + crop_pct = 0.9 + elif size == "s24": + config.depths = [4, 4, 12, 4] + config.hidden_sizes = [64, 128, 320, 512] + config.mlp_ratio = 4.0 + crop_pct = 0.9 + elif size == "s36": + config.depths = [6, 6, 18, 6] + config.hidden_sizes = [64, 128, 320, 512] + config.mlp_ratio = 4.0 + config.layer_scale_init_value = 1e-6 + crop_pct = 0.9 + elif size == "m36": + config.depths = [6, 6, 18, 6] + config.hidden_sizes = [96, 192, 384, 768] + config.mlp_ratio = 4.0 + config.layer_scale_init_value = 1e-6 + crop_pct = 0.95 + elif size == "m48": + config.depths = [8, 8, 24, 8] + config.hidden_sizes = [96, 192, 384, 768] + config.mlp_ratio = 4.0 + config.layer_scale_init_value = 1e-6 + crop_pct = 0.95 + else: + raise ValueError(f"Size {size} not supported") + + # load image processor + image_processor = PoolFormerImageProcessor(crop_pct=crop_pct) + + # Prepare image + image = prepare_img() + pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + + logger.info(f"Converting model {model_name}...") + + # load original state dict + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True) + + # rename keys + state_dict = rename_keys(state_dict) + + # create HuggingFace model and load state dict + model = PoolFormerForImageClassification(config) + model.load_state_dict(state_dict) + model.eval() + + # Define image processor + image_processor = PoolFormerImageProcessor(crop_pct=crop_pct) + pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values + + # forward pass + outputs = model(pixel_values) + logits = outputs.logits + + # define expected logit slices for different models + if size == "s12": + expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869]) + elif size == "s24": + expected_slice = torch.tensor([0.4402, -0.1374, -0.8045]) + elif size == "s36": + expected_slice = torch.tensor([-0.6080, -0.5133, -0.5898]) + elif size == "m36": + expected_slice = torch.tensor([0.3952, 0.2263, -1.2668]) + elif size == "m48": + expected_slice = torch.tensor([0.1167, -0.0656, -0.3423]) + else: + raise ValueError(f"Size {size} not supported") + + # verify logits + assert logits.shape == expected_shape + assert torch.allclose(logits[0, :3], expected_slice, atol=1e-2) + + # finally, save model and image processor + logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name", + default="poolformer_s12", + type=str, + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--checkpoint_path", 
default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + args = parser.parse_args() + convert_poolformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/poolformer/feature_extraction_poolformer.py b/docs/transformers/src/transformers/models/poolformer/feature_extraction_poolformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bde18b3ec0960f4425a1788b24cf838ff6459922 --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/feature_extraction_poolformer.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for PoolFormer.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_poolformer import PoolFormerImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class PoolFormerFeatureExtractor(PoolFormerImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class PoolFormerFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use PoolFormerImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["PoolFormerFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/poolformer/image_processing_poolformer.py b/docs/transformers/src/transformers/models/poolformer/image_processing_poolformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4c4bb77086a53ef7583d2569c24ff62c79cb20 --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/image_processing_poolformer.py @@ -0,0 +1,360 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
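+
+# Note on the resize-then-crop recipe implemented below: with the defaults
+# (size={"shortest_edge": 224}, crop_pct=0.9), an image is first resized so that its shortest
+# edge becomes int(224 / 0.9) = 248 pixels, and is then center cropped to 224x224. This is the
+# usual "resize larger, then crop" ImageNet evaluation setup.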
+"""Image processor class for PoolFormer.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +class PoolFormerImageProcessor(BaseImageProcessor): + r""" + Constructs a PoolFormer image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. If crop_pct is + unset: + - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`. + - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the + aspect ratio. + + If crop_pct is set: + - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)), + int(floor(w/crop_pct)))` + - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)` + whilst maintaining the aspect ratio. + - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)` + whilst maintaining the aspect ratio. + crop_pct (`float`, *optional*, defaults to 0.9): + Percentage of the image to crop from the center. Can be overridden by `crop_pct` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image + is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in the `preprocess` + method. + crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after applying center crop. Only has an effect if `do_center_crop` is set to `True`. Can + be overridden by the `crop_size` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. 
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        crop_pct: float = 0.9,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_center_crop: bool = True,
+        crop_size: Dict[str, int] = None,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 224}
+        size = get_size_dict(size, default_to_square=False)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+        self.do_resize = do_resize
+        self.size = size
+        self.crop_pct = crop_pct
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        crop_pct: Optional[float] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image.
+
+        If crop_pct is unset:
+
+        - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+        - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
+          aspect ratio.
+
+        If crop_pct is set:
+
+        - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
+          int(floor(w/crop_pct)))`
+        - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct))`
+          whilst maintaining the aspect ratio.
+        - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct))`
+          whilst maintaining the aspect ratio.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            crop_pct (`float`, *optional*):
+                Percentage of the image that will be cropped from the center. If set, the image is resized to
+                `size / crop_pct` before center cropping.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+ """ + size = get_size_dict(size, default_to_square=False) + if "shortest_edge" not in size and ("height" not in size or "width" not in size): + raise ValueError(f"size must contain 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") + if crop_pct is not None: + if "shortest_edge" in size: + scale_size = int(size["shortest_edge"] / crop_pct) + elif "height" in size and "width" in size: + if size["height"] == size["width"]: + scale_size = int(size["height"] / crop_pct) + else: + scale_size = (int(size["height"] / crop_pct), int(size["width"] / crop_pct)) + else: + raise ValueError("Invalid size for resize: {}".format(size)) + + output_size = get_resize_output_image_size( + image, size=scale_size, default_to_square=False, input_data_format=input_data_format + ) + else: + if "shortest_edge" in size: + output_size = get_resize_output_image_size( + image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + output_size = (size["height"], size["width"]) + else: + raise ValueError("Invalid size for resize: {}".format(size)) + + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + crop_pct: Optional[int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after applying resize. + crop_pct (`float`, *optional*, defaults to `self.crop_pct`): + Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after applying center crop. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. 
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + crop_pct = crop_pct if crop_pct is not None else self.crop_pct + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize( + image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format + ) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["PoolFormerImageProcessor"] diff --git a/docs/transformers/src/transformers/models/poolformer/image_processing_poolformer_fast.py b/docs/transformers/src/transformers/models/poolformer/image_processing_poolformer_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..37e81a5f5f9e6c1351bf790182763814946fa76c --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/image_processing_poolformer_fast.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for PoolFormer.""" + +from typing import Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, +) +from ...image_transforms import ( + ChannelDimension, + get_resize_output_image_size, + get_size_with_aspect_ratio, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size_for_max_height_width, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class PoolFormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + crop_pct: Optional[float] + + +@add_start_docstrings( + "Constructs a fast PoolFormer image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + crop_pct (`float`, *optional*, defaults to `self.crop_pct`): + Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`. 
+    """,
+)
+class PoolFormerImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = IMAGENET_DEFAULT_MEAN
+    image_std = IMAGENET_DEFAULT_STD
+    size = {"shortest_edge": 224}
+    default_to_square = False
+    crop_size = {"height": 224, "width": 224}
+    crop_pct = 0.9
+    do_resize = True
+    do_center_crop = True
+    do_rescale = True
+    do_normalize = True
+    valid_kwargs = PoolFormerFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[PoolFormerFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+
+    @add_start_docstrings(
+        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+        """
+        crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+            Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`.
+        """,
+    )
+    def preprocess(self, images: ImageInput, **kwargs: Unpack[PoolFormerFastImageProcessorKwargs]) -> BatchFeature:
+        return super().preprocess(images, **kwargs)
+
+    def resize(
+        self,
+        image: "torch.Tensor",
+        size: SizeDict,
+        crop_pct: Optional[float] = None,
+        interpolation: "F.InterpolationMode" = None,
+        antialias: bool = True,
+        **kwargs,
+    ) -> "torch.Tensor":
+        """
+        Resize an image.
+
+        If `crop_pct` is unset:
+        - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+        - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to `s` whilst maintaining the
+          aspect ratio.
+
+        If `crop_pct` is set:
+        - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h / crop_pct)),
+          int(floor(w / crop_pct)))`.
+        - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c / crop_pct))`
+          whilst maintaining the aspect ratio.
+        - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c / crop_pct))`
+          whilst maintaining the aspect ratio.
+
+        Args:
+            image (`torch.Tensor`):
+                Image to resize.
+            size (`SizeDict`):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            crop_pct (`float`, *optional*):
+                Percentage of the image that will later be cropped from the center. If set, the image is resized to
+                `size / crop_pct` so that the subsequent center crop of `size` keeps roughly `crop_pct` of the
+                resized image.
+            interpolation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+            antialias (`bool`, *optional*, defaults to `True`):
+                Whether to apply antialiasing when resizing the image.
+
+        Returns:
+            `torch.Tensor`: The resized image.
+        """
+        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+        if crop_pct is not None:
+            if size.shortest_edge:
+                scale_size = int(size.shortest_edge / crop_pct)
+            elif size.height and size.width:
+                if size.height == size.width:
+                    scale_size = int(size.height / crop_pct)
+                else:
+                    scale_size = (int(size.height / crop_pct), int(size.width / crop_pct))
+            else:
+                raise ValueError("Invalid size for resize: {}".format(size))
+
+            new_size = get_resize_output_image_size(
+                image,
+                size=scale_size,
+                default_to_square=False,
+                input_data_format=ChannelDimension.FIRST,
+            )
+        else:
+            if size.shortest_edge and size.longest_edge:
+                # Resize the image so that the shortest edge or the longest edge is of the given size
+                # while maintaining the aspect ratio of the original image.
+ new_size = get_size_with_aspect_ratio( + image.size()[-2:], + size.shortest_edge, + size.longest_edge, + ) + elif size.shortest_edge: + new_size = get_resize_output_image_size( + image, + size=size.shortest_edge, + default_to_square=False, + input_data_format=ChannelDimension.FIRST, + ) + elif size.max_height and size.max_width: + new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width) + elif size.height and size.width: + new_size = (size.height, size.width) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got" + f" {size}." + ) + return F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + + def center_crop( + self, + image: "torch.Tensor", + size: SizeDict, + **kwargs, + ) -> "torch.Tensor": + """ + Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along + any edge, the image is padded with 0's and then center cropped. + + Args: + image (`"torch.Tensor"`): + Image to center crop. + size (`Dict[str, int]`): + Size of the output image. + + Returns: + `torch.Tensor`: The center cropped image. + """ + if size.height is None or size.width is None: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}") + image_height, image_width = image.shape[-2:] + crop_height, crop_width = size.height, size.width + + if crop_width > image_width or crop_height > image_height: + padding_ltrb = [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + image = F.pad(image, padding_ltrb, fill=0) # PIL uses fill value 0 + image_height, image_width = image.shape[-2:] + if crop_width == image_width and crop_height == image_height: + return image + + crop_top = int((image_height - crop_height) / 2.0) + crop_left = int((image_width - crop_width) / 2.0) + return F.crop(image, crop_top, crop_left, crop_height, crop_width) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + crop_pct: float, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale 
and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["PoolFormerImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/poolformer/modeling_poolformer.py b/docs/transformers/src/transformers/models/poolformer/modeling_poolformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b2ed8868b19ee5eab5971636be40dd256672572 --- /dev/null +++ b/docs/transformers/src/transformers/models/poolformer/modeling_poolformer.py @@ -0,0 +1,452 @@ +# coding=utf-8 +# Copyright 2022 Sea AI Lab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PoolFormer model.""" + +import collections.abc +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention +from ...modeling_utils import PreTrainedModel +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_poolformer import PoolFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "PoolFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "sail/poolformer_s12" +_EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "sail/poolformer_s12" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
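+
+    Example (illustrative; with `training=False` the input is returned unchanged):
+
+    ```python
+    >>> import torch
+
+    >>> x = torch.ones(4, 3, 8, 8)
+    >>> torch.equal(drop_path(x, drop_prob=0.5, training=False), x)
+    True
+    ```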
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer +class PoolFormerDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class PoolFormerEmbeddings(nn.Module): + """ + Construct Patch Embeddings. + """ + + def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None): + super().__init__() + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) + padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding) + self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity() + + def forward(self, pixel_values): + embeddings = self.projection(pixel_values) + embeddings = self.norm(embeddings) + return embeddings + + +class PoolFormerGroupNorm(nn.GroupNorm): + """ + Group Normalization with 1 group. 
Input: tensor in shape [B, C, H, W] + """ + + def __init__(self, num_channels, **kwargs): + super().__init__(1, num_channels, **kwargs) + + +class PoolFormerPooling(nn.Module): + def __init__(self, pool_size): + super().__init__() + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + + def forward(self, hidden_states): + return self.pool(hidden_states) - hidden_states + + +class PoolFormerOutput(nn.Module): + def __init__(self, config, dropout_prob, hidden_size, intermediate_size): + super().__init__() + self.conv1 = nn.Conv2d(hidden_size, intermediate_size, 1) + self.conv2 = nn.Conv2d(intermediate_size, hidden_size, 1) + self.drop = PoolFormerDropPath(dropout_prob) + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.drop(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.drop(hidden_states) + + return hidden_states + + +class PoolFormerLayer(nn.Module): + """This corresponds to the 'PoolFormerBlock' class in the original implementation.""" + + def __init__(self, config, num_channels, pool_size, hidden_size, intermediate_size, drop_path): + super().__init__() + self.pooling = PoolFormerPooling(pool_size) + self.output = PoolFormerOutput(config, drop_path, hidden_size, intermediate_size) + self.before_norm = PoolFormerGroupNorm(num_channels) + self.after_norm = PoolFormerGroupNorm(num_channels) + + # Useful for training neural nets + self.drop_path = PoolFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.use_layer_scale = config.use_layer_scale + if config.use_layer_scale: + self.layer_scale_1 = nn.Parameter( + config.layer_scale_init_value * torch.ones((num_channels)), requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + config.layer_scale_init_value * torch.ones((num_channels)), requires_grad=True + ) + + def forward(self, hidden_states): + if self.use_layer_scale: + pooling_output = self.pooling(self.before_norm(hidden_states)) + scaled_op = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * pooling_output + # First residual connection + hidden_states = hidden_states + self.drop_path(scaled_op) + outputs = () + + layer_output = self.output(self.after_norm(hidden_states)) + scaled_op = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * layer_output + # Second residual connection + output = hidden_states + self.drop_path(scaled_op) + + outputs = (output,) + outputs + return outputs + + else: + pooling_output = self.drop_path(self.pooling(self.before_norm(hidden_states))) + # First residual connection + hidden_states = pooling_output + hidden_states + outputs = () + + # Second residual connection inside the PoolFormerOutput block + layer_output = self.drop_path(self.output(self.after_norm(hidden_states))) + output = hidden_states + layer_output + + outputs = (output,) + outputs + return outputs + + +class PoolFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + + # patch embeddings + embeddings = [] + for i in range(config.num_encoder_blocks): + embeddings.append( + PoolFormerEmbeddings( + patch_size=config.patch_sizes[i], + stride=config.strides[i], + padding=config.padding[i], + 
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], + hidden_size=config.hidden_sizes[i], + ) + ) + self.patch_embeddings = nn.ModuleList(embeddings) + + # Transformer blocks + blocks = [] + cur = 0 + for i in range(config.num_encoder_blocks): + # each block consists of layers + layers = [] + if i != 0: + cur += config.depths[i - 1] + for j in range(config.depths[i]): + layers.append( + PoolFormerLayer( + config, + num_channels=config.hidden_sizes[i], + pool_size=config.pool_size, + hidden_size=config.hidden_sizes[i], + intermediate_size=int(config.hidden_sizes[i] * config.mlp_ratio), + drop_path=dpr[cur + j], + ) + ) + blocks.append(nn.ModuleList(layers)) + + self.block = nn.ModuleList(blocks) + + def forward(self, pixel_values, output_hidden_states=False, return_dict=True): + all_hidden_states = () if output_hidden_states else None + + hidden_states = pixel_values + for idx, layers in enumerate(zip(self.patch_embeddings, self.block)): + embedding_layer, block_layer = layers + # Get patch embeddings from hidden_states + hidden_states = embedding_layer(hidden_states) + # Send the embeddings through the blocks + for _, blk in enumerate(block_layer): + layer_outputs = blk(hidden_states) + hidden_states = layer_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + +class PoolFormerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PoolFormerConfig + base_model_prefix = "poolformer" + main_input_name = "pixel_values" + _no_split_modules = ["PoolFormerLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.GroupNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, PoolFormerLayer): + if hasattr(module, "layer_scale_1"): + module.layer_scale_1.data.fill_(self.config.layer_scale_init_value) + module.layer_scale_2.data.fill_(self.config.layer_scale_init_value) + + +POOLFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PoolFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +POOLFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`PoolFormerImageProcessor.__call__`] for details. 
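+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.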
+""" + + +@add_start_docstrings( + "The bare PoolFormer Model transformer outputting raw hidden-states without any specific head on top.", + POOLFORMER_START_DOCSTRING, +) +class PoolFormerModel(PoolFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.encoder = PoolFormerEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output, None) + encoder_outputs[1:] + + return BaseModelOutputWithNoAttention( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class PoolFormerFinalPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states): + output = self.dense(hidden_states) + return output + + +@add_start_docstrings( + """ + PoolFormer Model transformer with an image classification head on top + """, + POOLFORMER_START_DOCSTRING, +) +class PoolFormerForImageClassification(PoolFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.poolformer = PoolFormerModel(config) + + # Final norm + self.norm = PoolFormerGroupNorm(config.hidden_sizes[-1]) + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.poolformer( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(self.norm(sequence_output).mean([-2, -1])) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +__all__ = ["PoolFormerForImageClassification", "PoolFormerModel", "PoolFormerPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/pop2piano/__init__.py b/docs/transformers/src/transformers/models/pop2piano/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbfcbf157b58aa582e2a8b2a99c7d561372904b --- /dev/null +++ b/docs/transformers/src/transformers/models/pop2piano/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pop2piano import * + from .modeling_pop2piano import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/pop2piano/configuration_pop2piano.py b/docs/transformers/src/transformers/models/pop2piano/configuration_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..484e1a4f933e9626919df8c522432e07184ec4c1 --- /dev/null +++ b/docs/transformers/src/transformers/models/pop2piano/configuration_pop2piano.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pop2Piano model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Pop2PianoConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Pop2PianoForConditionalGeneration`]. It is used + to instantiate a Pop2PianoForConditionalGeneration model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the + Pop2Piano [sweetcocoa/pop2piano](https://huggingface.co/sweetcocoa/pop2piano) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 2400): + Vocabulary size of the `Pop2PianoForConditionalGeneration` model. Defines the number of different tokens + that can be represented by the `inputs_ids` passed when calling [`Pop2PianoForConditionalGeneration`]. + composer_vocab_size (`int`, *optional*, defaults to 21): + Denotes the number of composers. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `Pop2PianoBlock`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + dense_act_fn (`string`, *optional*, defaults to `"relu"`): + Type of Activation Function to be used in `Pop2PianoDenseActDense` and in `Pop2PianoDenseGatedActDense`. 
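+
+    Example (illustrative):
+
+    ```python
+    >>> from transformers import Pop2PianoConfig, Pop2PianoForConditionalGeneration
+
+    >>> # Initializing a Pop2Piano configuration (the defaults match sweetcocoa/pop2piano)
+    >>> configuration = Pop2PianoConfig()
+
+    >>> # Initializing a model (with random weights) from that configuration
+    >>> model = Pop2PianoForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```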
+ """ + + model_type = "pop2piano" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=2400, + composer_vocab_size=21, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="gated-gelu", # noqa + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + dense_act_fn="relu", + **kwargs, + ): + self.vocab_size = vocab_size + self.composer_vocab_size = composer_vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + self.dense_act_fn = dense_act_fn + self.is_gated_act = self.feed_forward_proj.split("-")[0] == "gated" + self.hidden_size = self.d_model + self.num_attention_heads = num_heads + self.num_hidden_layers = num_layers + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +__all__ = ["Pop2PianoConfig"] diff --git a/docs/transformers/src/transformers/models/pop2piano/convert_pop2piano_weights_to_hf.py b/docs/transformers/src/transformers/models/pop2piano/convert_pop2piano_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..84788ac6aecf63d58becc522502ec99c720dce5d --- /dev/null +++ b/docs/transformers/src/transformers/models/pop2piano/convert_pop2piano_weights_to_hf.py @@ -0,0 +1,190 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""File for loading the Pop2Piano model weights from the official repository and for showing how the tokenizer vocab
+was constructed"""
+
+import json
+
+import torch
+
+from transformers import Pop2PianoConfig, Pop2PianoForConditionalGeneration
+
+
+########################## MODEL WEIGHTS ##########################
+
+# These weights were downloaded from the official pop2piano repository
+# https://huggingface.co/sweetcocoa/pop2piano/blob/main/model-1999-val_0.67311615.ckpt
+official_weights = torch.load("./model-1999-val_0.67311615.ckpt", weights_only=True)
+state_dict = {}
+
+
+# load the config and init the model
+cfg = Pop2PianoConfig.from_pretrained("sweetcocoa/pop2piano")
+model = Pop2PianoForConditionalGeneration(cfg)
+
+
+# load relative attention bias
+state_dict["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = official_weights["state_dict"][
+    "transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+]
+state_dict["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = official_weights["state_dict"][
+    "transformer.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
+]
+
+# load embed tokens and final layer norm for both encoder and decoder
+state_dict["encoder.embed_tokens.weight"] = official_weights["state_dict"]["transformer.encoder.embed_tokens.weight"]
+state_dict["decoder.embed_tokens.weight"] = official_weights["state_dict"]["transformer.decoder.embed_tokens.weight"]
+
+state_dict["encoder.final_layer_norm.weight"] = official_weights["state_dict"][
+    "transformer.encoder.final_layer_norm.weight"
+]
+state_dict["decoder.final_layer_norm.weight"] = official_weights["state_dict"][
+    "transformer.decoder.final_layer_norm.weight"
+]
+
+# load lm_head, mel_conditioner.emb and shared
+state_dict["lm_head.weight"] = official_weights["state_dict"]["transformer.lm_head.weight"]
+state_dict["mel_conditioner.embedding.weight"] = official_weights["state_dict"]["mel_conditioner.embedding.weight"]
+state_dict["shared.weight"] = official_weights["state_dict"]["transformer.shared.weight"]
+
+# load each encoder block
+for i in range(cfg.num_layers):
+    # layer 0
+    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.0.SelfAttention.q.weight"
+    ]
+    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.0.SelfAttention.k.weight"
+    ]
+    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.0.SelfAttention.v.weight"
+    ]
+    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.0.SelfAttention.o.weight"
+    ]
+    state_dict[f"encoder.block.{i}.layer.0.layer_norm.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.0.layer_norm.weight"
+    ]
+
+    # layer 1
+    state_dict[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
+    ]
+    state_dict[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = official_weights["state_dict"][
+        f"transformer.encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
+    ]
+    state_dict[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = official_weights["state_dict"][
f"transformer.encoder.block.{i}.layer.1.DenseReluDense.wo.weight" + ] + state_dict[f"encoder.block.{i}.layer.1.layer_norm.weight"] = official_weights["state_dict"][ + f"transformer.encoder.block.{i}.layer.1.layer_norm.weight" + ] + +# load each decoder blocks +for i in range(6): + # layer 0 + state_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.0.SelfAttention.q.weight" + ] + state_dict[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.0.SelfAttention.k.weight" + ] + state_dict[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.0.SelfAttention.v.weight" + ] + state_dict[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.0.SelfAttention.o.weight" + ] + state_dict[f"decoder.block.{i}.layer.0.layer_norm.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.0.layer_norm.weight" + ] + + # layer 1 + state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.1.EncDecAttention.q.weight" + ] + state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.1.EncDecAttention.k.weight" + ] + state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.1.EncDecAttention.v.weight" + ] + state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.1.EncDecAttention.o.weight" + ] + state_dict[f"decoder.block.{i}.layer.1.layer_norm.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.1.layer_norm.weight" + ] + + # layer 2 + state_dict[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight" + ] + state_dict[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight" + ] + state_dict[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.2.DenseReluDense.wo.weight" + ] + state_dict[f"decoder.block.{i}.layer.2.layer_norm.weight"] = official_weights["state_dict"][ + f"transformer.decoder.block.{i}.layer.2.layer_norm.weight" + ] + +model.load_state_dict(state_dict, strict=True) + +# save the weights +torch.save(state_dict, "./pytorch_model.bin") + +########################## TOKENIZER ########################## + +# the tokenize and detokenize methods are taken from the official implementation + + +# link : https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L34 +def tokenize(idx, token_type, n_special=4, n_note=128, n_velocity=2): + if token_type == "TOKEN_TIME": + return n_special + n_note + n_velocity + idx + elif token_type == "TOKEN_VELOCITY": + return n_special + n_note + idx + elif token_type == "TOKEN_NOTE": + return n_special + idx + elif token_type == "TOKEN_SPECIAL": + return idx + else: + return -1 + + +# link : 
+def detokenize(idx, n_special=4, n_note=128, n_velocity=2, time_idx_offset=0):
+    if idx >= n_special + n_note + n_velocity:
+        return "TOKEN_TIME", (idx - (n_special + n_note + n_velocity)) + time_idx_offset
+    elif idx >= n_special + n_note:
+        return "TOKEN_VELOCITY", idx - (n_special + n_note)
+    elif idx >= n_special:
+        return "TOKEN_NOTE", idx - n_special
+    else:
+        return "TOKEN_SPECIAL", idx
+
+
+# create the decoder and then the encoder of the tokenizer
+decoder = {}
+for i in range(cfg.vocab_size):
+    decoder.update({i: f"{detokenize(i)[1]}_{detokenize(i)[0]}"})
+
+encoder = {v: k for k, v in decoder.items()}
+
+# save the vocab
+with open("./vocab.json", "w") as file:
+    file.write(json.dumps(encoder))
diff --git a/docs/transformers/src/transformers/models/pop2piano/feature_extraction_pop2piano.py b/docs/transformers/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b191ab106b100d9f4f8d57a92b56f715abd8fad
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
@@ -0,0 +1,455 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Pop2Piano"""
+
+import warnings
+from typing import List, Optional, Union
+
+import numpy
+import numpy as np
+
+from ...audio_utils import mel_filter_bank, spectrogram
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import (
+    TensorType,
+    is_essentia_available,
+    is_librosa_available,
+    is_scipy_available,
+    logging,
+    requires_backends,
+)
+from ...utils.import_utils import requires
+
+
+if is_essentia_available():
+    import essentia
+    import essentia.standard
+
+if is_librosa_available():
+    import librosa
+
+if is_scipy_available():
+    import scipy
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("essentia", "librosa", "scipy", "torch"))
+class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs a Pop2Piano feature extractor.
+
+    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+    most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed
+    to the `RhythmExtractor2013` algorithm, which extracts beat_times and beat positions and estimates their
+    confidence as well as the tempo in bpm. beat_times is then interpolated to get beatsteps, from which we calculate
+    the extrapolated_beatsteps used in the tokenizer. In parallel, the audio is resampled to `self.sampling_rate` and
+    preprocessed, and a log-mel spectrogram is computed from it to be used in our transformer model.
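+
+    Example (an illustrative sketch; assumes the `essentia`, `librosa`, `scipy` and `torch` backends are installed,
+    and uses a dummy waveform in place of real audio):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import Pop2PianoFeatureExtractor
+
+    >>> feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
+    >>> raw_audio = np.random.rand(5 * 44100).astype(np.float32)  # 5 seconds of mono audio at 44.1 kHz
+    >>> features = feature_extractor(raw_audio, sampling_rate=44100, return_tensors="pt")
+    ```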
+
+    Args:
+        sampling_rate (`int`, *optional*, defaults to 22050):
+            Target sampling rate of the audio signal. It's the sampling rate that we forward to the model.
+        padding_value (`int`, *optional*, defaults to 0):
+            Padding value used to pad the audio. Should correspond to silences.
+        window_size (`int`, *optional*, defaults to 4096):
+            Length of the window in samples to which the Fourier transform is applied.
+        hop_length (`int`, *optional*, defaults to 1024):
+            Step size between each window of the waveform, in samples.
+        min_frequency (`float`, *optional*, defaults to 10.0):
+            Lowest frequency that will be used in the log-mel spectrogram.
+        feature_size (`int`, *optional*, defaults to 512):
+            The feature dimension of the extracted features.
+        num_bars (`int`, *optional*, defaults to 2):
+            Determines the interval between each sequence.
+    """
+
+    model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"]
+
+    def __init__(
+        self,
+        sampling_rate: int = 22050,
+        padding_value: int = 0,
+        window_size: int = 4096,
+        hop_length: int = 1024,
+        min_frequency: float = 10.0,
+        feature_size: int = 512,
+        num_bars: int = 2,
+        **kwargs,
+    ):
+        super().__init__(
+            feature_size=feature_size,
+            sampling_rate=sampling_rate,
+            padding_value=padding_value,
+            **kwargs,
+        )
+        self.sampling_rate = sampling_rate
+        self.padding_value = padding_value
+        self.window_size = window_size
+        self.hop_length = hop_length
+        self.min_frequency = min_frequency
+        self.feature_size = feature_size
+        self.num_bars = num_bars
+        self.mel_filters = mel_filter_bank(
+            num_frequency_bins=(self.window_size // 2) + 1,
+            num_mel_filters=self.feature_size,
+            min_frequency=self.min_frequency,
+            max_frequency=float(self.sampling_rate // 2),
+            sampling_rate=self.sampling_rate,
+            norm=None,
+            mel_scale="htk",
+        )
+
+    def mel_spectrogram(self, sequence: np.ndarray):
+        """
+        Generates a mel-spectrogram.
+
+        Args:
+            sequence (`numpy.ndarray`):
+                The sequence of which the mel-spectrogram will be computed.
+        """
+        mel_specs = []
+        for seq in sequence:
+            window = np.hanning(self.window_size + 1)[:-1]
+            mel_specs.append(
+                spectrogram(
+                    waveform=seq,
+                    window=window,
+                    frame_length=self.window_size,
+                    hop_length=self.hop_length,
+                    power=2.0,
+                    mel_filters=self.mel_filters,
+                )
+            )
+        mel_specs = np.array(mel_specs)
+
+        return mel_specs
+
+    def extract_rhythm(self, audio: np.ndarray):
+        """
+        This algorithm (`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as
+        the tempo in bpm for an audio signal. For more information please visit
+        https://essentia.upf.edu/reference/std_RhythmExtractor2013.html .
+
+        Args:
+            audio (`numpy.ndarray`):
+                Raw audio waveform which is passed to the Rhythm Extractor.
+        """
+        requires_backends(self, ["essentia"])
+        essentia_tracker = essentia.standard.RhythmExtractor2013(method="multifeature")
+        bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio)
+
+        return bpm, beat_times, confidence, estimates, essentia_beat_intervals
+
+    def interpolate_beat_times(
+        self, beat_times: numpy.ndarray, steps_per_beat: numpy.ndarray, n_extend: numpy.ndarray
+    ):
+        """
+        This method takes beat_times and interpolates it using `scipy.interpolate.interp1d`; the output is then used
+        to convert the raw audio to a log-mel-spectrogram.
+
+        Args:
+            beat_times (`numpy.ndarray`):
+                beat_times is passed into `scipy.interpolate.interp1d` for processing.
+            steps_per_beat (`int`):
+                Used as a parameter to control the interpolation.
+            n_extend (`int`):
+                Used as a parameter to control the interpolation.
+        """
+
+        requires_backends(self, ["scipy"])
+        beat_times_function = scipy.interpolate.interp1d(
+            np.arange(beat_times.size),
+            beat_times,
+            bounds_error=False,
+            fill_value="extrapolate",
+        )
+
+        ext_beats = beat_times_function(
+            np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend)
+        )
+
+        return ext_beats
+
+    def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray):
+        """
+        Preprocessing for the log-mel-spectrogram.
+
+        Args:
+            audio (`numpy.ndarray` of shape `(audio_length, )`):
+                Raw audio waveform to be processed.
+            beatstep (`numpy.ndarray`):
+                Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by
+                the value at beatstep[0].
+        """
+
+        if audio is not None and len(audio.shape) != 1:
+            raise ValueError(
+                f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}."
+            )
+        if beatstep[0] > 0.0:
+            beatstep = beatstep - beatstep[0]
+
+        num_steps = self.num_bars * 4
+        num_target_steps = len(beatstep)
+        extrapolated_beatstep = self.interpolate_beat_times(
+            beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1
+        )
+
+        sample_indices = []
+        max_feature_length = 0
+        for i in range(0, num_target_steps, num_steps):
+            start_idx = i
+            end_idx = min(i + num_steps, num_target_steps)
+            start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate)
+            end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate)
+            sample_indices.append((start_sample, end_sample))
+            max_feature_length = max(max_feature_length, end_sample - start_sample)
+        padded_batch = []
+        for start_sample, end_sample in sample_indices:
+            feature = audio[start_sample:end_sample]
+            padded_feature = np.pad(
+                feature,
+                ((0, max_feature_length - feature.shape[0]),),
+                "constant",
+                constant_values=0,
+            )
+            padded_batch.append(padded_feature)
+
+        padded_batch = np.asarray(padded_batch)
+        return padded_batch, extrapolated_beatstep
+
+    def _pad(self, features: np.ndarray, add_zero_line=True):
+        features_shapes = [each_feature.shape for each_feature in features]
+        attention_masks, padded_features = [], []
+        for i, each_feature in enumerate(features):
+            # To pad "input_features".
+            if len(each_feature.shape) == 3:
+                features_pad_value = max([*zip(*features_shapes)][1]) - features_shapes[i][1]
+                attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64)
+                feature_padding = ((0, 0), (0, features_pad_value), (0, 0))
+                attention_mask_padding = (feature_padding[0], feature_padding[1])
+
+            # To pad "beatsteps" and "extrapolated_beatstep".
+            else:
+                each_feature = each_feature.reshape(1, -1)
+                features_pad_value = max([*zip(*features_shapes)][0]) - features_shapes[i][0]
+                attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1)
+                feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value))
+
+            each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value)
+            attention_mask = np.pad(
+                attention_mask, attention_mask_padding, "constant", constant_values=self.padding_value
+            )
+
+            if add_zero_line:
+                # if it is batched then we separate each example using a zero array
+                zero_array_len = max([*zip(*features_shapes)][1])
+
+                # we concatenate the zero array line here
+                each_padded_feature = np.concatenate(
+                    [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0
+                )
+                attention_mask = np.concatenate(
+                    [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0
+                )
+
+            padded_features.append(each_padded_feature)
+            attention_masks.append(attention_mask)
+
+        padded_features = np.concatenate(padded_features, axis=0).astype(np.float32)
+        attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64)
+
+        return padded_features, attention_masks
+
+    def pad(
+        self,
+        inputs: BatchFeature,
+        is_batched: bool,
+        return_attention_mask: bool,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ):
+        """
+        Pads the inputs to the same length and returns attention_mask.
+
+        Args:
+            inputs (`BatchFeature`):
+                Processed audio features.
+            is_batched (`bool`):
+                Whether inputs are batched or not.
+            return_attention_mask (`bool`):
+                Whether to return attention mask or not.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of a list of python integers. Acceptable values are:
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+                If nothing is specified, it will return a list of `np.ndarray` arrays.
+        Return:
+            `BatchFeature` with attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep added
+            to it:
+            - **attention_mask** numpy.ndarray of shape `(batch_size, max_input_features_seq_length)` --
+                Example :
+                    1, 1, 1, 0, 0 (audio 1, padded to the max length of 5, which is why there are 2 zeros at
+                    the end indicating those positions are padding)
+
+                    0, 0, 0, 0, 0 (zero pad to separate audio 1 and 2)
+
+                    1, 1, 1, 1, 1 (audio 2)
+
+                    0, 0, 0, 0, 0 (zero pad to separate audio 2 and 3)
+
+                    1, 1, 1, 1, 1 (audio 3)
+            - **attention_mask_beatsteps** numpy.ndarray of shape `(batch_size, max_beatsteps_seq_length)`
+            - **attention_mask_extrapolated_beatstep** numpy.ndarray of shape `(batch_size,
+              max_extrapolated_beatstep_seq_length)`
+        """
+
+        processed_features_dict = {}
+        for feature_name, feature_value in inputs.items():
+            if feature_name == "input_features":
+                padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=True)
+                processed_features_dict[feature_name] = padded_feature_values
+                if return_attention_mask:
+                    processed_features_dict["attention_mask"] = attention_mask
+            else:
+                padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=False)
+                processed_features_dict[feature_name] = padded_feature_values
+                if return_attention_mask:
+                    processed_features_dict[f"attention_mask_{feature_name}"] = attention_mask
+
+        # If we are processing only one example, we should remove the zero array line since we don't need it to
+        # separate examples from each other.
+ if not is_batched and not return_attention_mask: + processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...] + + outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors) + + return outputs + + def __call__( + self, + audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + sampling_rate: Union[int, List[int]], + steps_per_beat: int = 2, + resample: Optional[bool] = True, + return_attention_mask: Optional[bool] = False, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model. + + Args: + audio (`np.ndarray`, `List`): + The audio or batch of audio to be processed. Each audio can be a numpy array, a list of float values, a + list of numpy arrays or a list of list of float values. + sampling_rate (`int`): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + steps_per_beat (`int`, *optional*, defaults to 2): + This is used in interpolating `beat_times`. + resample (`bool`, *optional*, defaults to `True`): + Determines whether to resample the audio to `sampling_rate` or not before processing. Must be True + during inference. + return_attention_mask (`bool` *optional*, defaults to `False`): + Denotes if attention_mask for input_features, beatsteps and extrapolated_beatstep will be given as + output or not. Automatically set to True for batched inputs. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + If nothing is specified, it will return list of `np.ndarray` arrays. + """ + + requires_backends(self, ["librosa"]) + is_batched = bool(isinstance(audio, (list, tuple)) and isinstance(audio[0], (np.ndarray, tuple, list))) + if is_batched: + # This enables the user to process files of different sampling_rate at same time + if not isinstance(sampling_rate, list): + raise ValueError( + "Please give sampling_rate of each audio separately when you are passing multiple raw_audios at the same time. " + f"Received {sampling_rate}, expected [audio_1_sr, ..., audio_n_sr]." + ) + return_attention_mask = True if return_attention_mask is None else return_attention_mask + else: + audio = [audio] + sampling_rate = [sampling_rate] + return_attention_mask = False if return_attention_mask is None else return_attention_mask + + batch_input_features, batch_beatsteps, batch_ext_beatstep = [], [], [] + for single_raw_audio, single_sampling_rate in zip(audio, sampling_rate): + bpm, beat_times, confidence, estimates, essentia_beat_intervals = self.extract_rhythm( + audio=single_raw_audio + ) + beatsteps = self.interpolate_beat_times(beat_times=beat_times, steps_per_beat=steps_per_beat, n_extend=1) + + if self.sampling_rate != single_sampling_rate and self.sampling_rate is not None: + if resample: + # Change sampling_rate to self.sampling_rate + single_raw_audio = librosa.core.resample( + single_raw_audio, + orig_sr=single_sampling_rate, + target_sr=self.sampling_rate, + res_type="kaiser_best", + ) + else: + warnings.warn( + f"The sampling_rate of the provided audio is different from the target sampling_rate " + f"of the Feature Extractor, {self.sampling_rate} vs {single_sampling_rate}. 
" + f"In these cases it is recommended to use `resample=True` in the `__call__` method to " + f"get the optimal behaviour." + ) + + single_sampling_rate = self.sampling_rate + start_sample = int(beatsteps[0] * single_sampling_rate) + end_sample = int(beatsteps[-1] * single_sampling_rate) + + input_features, extrapolated_beatstep = self.preprocess_mel( + single_raw_audio[start_sample:end_sample], beatsteps - beatsteps[0] + ) + + mel_specs = self.mel_spectrogram(input_features.astype(np.float32)) + + # apply np.log to get log mel-spectrograms + log_mel_specs = np.log(np.clip(mel_specs, a_min=1e-6, a_max=None)) + + input_features = np.transpose(log_mel_specs, (0, -1, -2)) + + batch_input_features.append(input_features) + batch_beatsteps.append(beatsteps) + batch_ext_beatstep.append(extrapolated_beatstep) + + output = BatchFeature( + { + "input_features": batch_input_features, + "beatsteps": batch_beatsteps, + "extrapolated_beatstep": batch_ext_beatstep, + } + ) + + output = self.pad( + output, + is_batched=is_batched, + return_attention_mask=return_attention_mask, + return_tensors=return_tensors, + ) + + return output + + +__all__ = ["Pop2PianoFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/pop2piano/modeling_pop2piano.py b/docs/transformers/src/transformers/models/pop2piano/modeling_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8b479554374f44105f76c060d7f58e99e14eb0 --- /dev/null +++ b/docs/transformers/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -0,0 +1,1494 @@ +# coding=utf-8 +# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Pop2Piano model.""" + +import copy +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.generation import GenerationConfig + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_pop2piano import Pop2PianoConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_load_pop2piano_layer_norm = True + +try: + from apex.normalization import FusedRMSNorm + + _load_pop2piano_layer_norm = False + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm") +except ImportError: + # using the normal Pop2PianoLayerNorm + pass +except Exception: + logger.warning("Discovered apex but it failed to load, falling back to Pop2PianoLayerNorm") + pass + + +_CONFIG_FOR_DOC = "Pop2PianoConfig" +_CHECKPOINT_FOR_DOC = "sweetcocoa/pop2piano" + + +POP2PIANO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings + so you should be able to pad the inputs on both the right and the left. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. + [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining + take a look a [Pop2Piano Training](./Pop2Piano#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the + starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last + `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. 
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Does the same task as `inputs_embeds`. If `inputs_embeds` is not present but `input_features` is present + then `input_features` will be considered as `inputs_embeds`. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If + `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of + `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. +""" + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pop2Piano +class Pop2PianoLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # Pop2Piano uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +if not _load_pop2piano_layer_norm: + Pop2PianoLayerNorm = FusedRMSNorm # noqa + +ALL_LAYERNORM_LAYERS.append(Pop2PianoLayerNorm) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Pop2Piano,t5->pop2piano +class Pop2PianoDenseActDense(nn.Module): + def __init__(self, config: Pop2PianoConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pop2Piano +class Pop2PianoDenseGatedActDense(nn.Module): + def __init__(self, config: Pop2PianoConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+        # See https://github.com/huggingface/transformers/issues/20287
+        # We also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Pop2Piano
+class Pop2PianoLayerFF(nn.Module):
+    def __init__(self, config: Pop2PianoConfig):
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = Pop2PianoDenseGatedActDense(config)
+        else:
+            self.DenseReluDense = Pop2PianoDenseActDense(config)
+
+        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, hidden_states):
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano
+class Pop2PianoAttention(nn.Module):
+    def __init__(
+        self,
+        config: Pop2PianoConfig,
+        has_relative_attention_bias=False,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None and self.is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+ ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
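+
+        Queries are projected to shape `(batch_size, n_heads, seq_length, key_value_proj_dim)`. For
+        cross-attention, the key/value projections are computed from `key_value_states` and, after the
+        first decoding step, re-used from the cross-attention cache instead of being recomputed.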
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * 
layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano +class Pop2PianoLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = Pop2PianoAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano +class Pop2PianoLayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano +class Pop2PianoBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + Pop2PianoLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) + if self.is_decoder: + self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx)) + + 
self.layer.append(Pop2PianoLayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states, past_key_value = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, past_key_value = cross_attention_outputs[:2] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (past_key_value,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class Pop2PianoPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
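+
+    For example, `_shift_right` (defined below) turns labels such as `[[23, 45, 67, -100]]` into
+    decoder inputs `[[0, 23, 45, 67]]`: the start token is prepended, the last position is dropped,
+    and any remaining `-100` is replaced by `pad_token_id` (assuming `decoder_start_token_id ==
+    pad_token_id == 0`, the usual Pop2Piano setting).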
+ """ + + config_class = Pop2PianoConfig + base_model_prefix = "transformer" + is_parallelizable = False + supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False + _no_split_modules = ["Pop2PianoBlock"] + _keep_in_fp32_modules = ["wo"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, Pop2PianoLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, Pop2PianoConcatEmbeddingToMel): + module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, Pop2PianoForConditionalGeneration): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, Pop2PianoDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, Pop2PianoDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, Pop2PianoAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id." 
+ ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class Pop2PianoStack(Pop2PianoPreTrainedModel): + # Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + 
encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + cache_position, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, next_decoder_cache = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == 
"flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. 
+ cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Pop2PianoConcatEmbeddingToMel(nn.Module): + """Embedding Matrix for `composer` tokens.""" + + def __init__(self, config): + super().__init__() + self.embedding = nn.Embedding(num_embeddings=config.composer_vocab_size, embedding_dim=config.d_model) + + def forward(self, feature, index_value, embedding_offset): + index_shifted = index_value - embedding_offset + composer_embedding = self.embedding(index_shifted).unsqueeze(1) + inputs_embeds = torch.cat([composer_embedding, feature], dim=1) + return inputs_embeds + + +Pop2Piano_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Pop2PianoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings("""Pop2Piano Model with a `language modeling` head on top.""", Pop2Piano_START_DOCSTRING) +class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: Pop2PianoConfig): + super().__init__(config) + self.config = config + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + + self.encoder = Pop2PianoStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = Pop2PianoStack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_mel_conditioner_outputs( + self, + input_features: torch.FloatTensor, + composer: str, + generation_config: GenerationConfig, + attention_mask: Optional[torch.FloatTensor] = None, + ): + """ + This method is used to concatenate mel conditioner tokens at the front of the input_features in order to + control the type of MIDI token generated by the model. + + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + input features extracted from the feature extractor. + composer (`str`): + composer token which determines the type of MIDI tokens to be generated. + generation_config (`~generation.GenerationConfig`): + The generation is used to get the composer-feature_token pair. + attention_mask (``, *optional*): + For batched generation `input_features` are padded to have the same shape across all examples. + `attention_mask` helps to determine which areas were padded and which were not. + - 1 for tokens that are **not padded**, + - 0 for tokens that are **padded**. + """ + composer_to_feature_token = generation_config.composer_to_feature_token + if composer not in composer_to_feature_token.keys(): + raise ValueError( + f"Please choose a composer from {list(composer_to_feature_token.keys())}. 
Composer received - {composer}" + ) + composer_value = composer_to_feature_token[composer] + composer_value = torch.tensor(composer_value, device=self.device) + composer_value = composer_value.repeat(input_features.shape[0]) + + embedding_offset = min(composer_to_feature_token.values()) + + input_features = self.mel_conditioner( + feature=input_features, + index_value=composer_value, + embedding_offset=embedding_offset, + ) + if attention_mask is not None: + input_features[~attention_mask[:, 0].bool()] = 0.0 + + # since self.mel_conditioner adds a new array at the front of inputs_embeds we need to do the same for attention_mask to keep the shapes same + attention_mask = torch.concatenate([attention_mask[:, 0].view(-1, 1), attention_mask], axis=1) + return input_features, attention_mask + + return input_features, None + + @add_start_docstrings_to_model_forward(POP2PIANO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + Returns: + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None and input_features is not None: + raise ValueError("Both `inputs_embeds` and `input_features` received! 
Please provide only one of them") + elif input_features is not None and inputs_embeds is None: + inputs_embeds = input_features + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features, + attention_mask=None, + composer="composer1", + generation_config=None, + **kwargs, + ): + """ + Generates token ids for midi outputs. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation + strategies and code examples, check out the [following guide](./generation_strategies). + + + + Parameters: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`. 
+            attention_mask:
+                For batched generation `input_features` are padded to have the same shape across all examples.
+                `attention_mask` helps to determine which areas were padded and which were not.
+                - 1 for tokens that are **not padded**,
+                - 0 for tokens that are **padded**.
+            composer (`str`, *optional*, defaults to `"composer1"`):
+                This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
+                `"composer"`. Please make sure that the composer value is present in `composer_to_feature_token` in
+                `generation_config`. For an example please see
+                https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            kwargs:
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` of generated token ids.
+            Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+            [`~utils.ModelOutput`] types are:
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
+        """
+
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config.update(**kwargs)
+
+        # check for composer_to_feature_token
+        if not hasattr(generation_config, "composer_to_feature_token"):
+            raise ValueError(
+                "`composer_to_feature_token` was not found! Please refer to "
+                "https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json "
+                "and pass a dict like that."
+            )
+
+        if len(generation_config.composer_to_feature_token) != self.config.composer_vocab_size:
+            raise ValueError(
+                "config.composer_vocab_size must be the same as the number of keys in "
+                "generation_config.composer_to_feature_token! "
+                f"Found {self.config.composer_vocab_size} vs {len(generation_config.composer_to_feature_token)}."
+            )
+
+        # to control the variation of generated MIDI tokens we concatenate mel-conditioner tokens (which depend on
+        # the composer token) at the front of input_features.
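+        # e.g. for a batch of size B with T feature frames of hidden size H, the call below is expected to turn
+        # input_features of shape (B, T, H) into (B, T + 1, H) and to grow attention_mask from (B, T) to (B, T + 1)
+        # so that both stay aligned (shapes are illustrative, inferred from `get_mel_conditioner_outputs` above).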
+ input_features, attention_mask = self.get_mel_conditioner_outputs( + input_features=input_features, + attention_mask=attention_mask, + composer=composer, + generation_config=generation_config, + ) + + return super().generate( + inputs=None, + inputs_embeds=input_features, + attention_mask=attention_mask, + generation_config=generation_config, + **kwargs, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +__all__ = ["Pop2PianoForConditionalGeneration", "Pop2PianoPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/pop2piano/processing_pop2piano.py b/docs/transformers/src/transformers/models/pop2piano/processing_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..3b839f8b1fd548a821f9358475fa0b9f9ea7fd09 --- /dev/null +++ b/docs/transformers/src/transformers/models/pop2piano/processing_pop2piano.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
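+# A processor bundles the feature extractor (raw audio -> log-mel spectrogram features) and the tokenizer
+# (token ids <-> MIDI notes) behind a single API, mirroring the other multimodal processors in the library.
+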
+"""Processor class for Pop2Piano.""" + +import os +from typing import List, Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessorMixin +from ...tokenization_utils import BatchEncoding, PaddingStrategy, TruncationStrategy +from ...utils import TensorType +from ...utils.import_utils import requires + + +@requires(backends=("essentia", "librosa", "pretty_midi", "scipy", "torch")) +class Pop2PianoProcessor(ProcessorMixin): + r""" + Constructs an Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single + processor. + + [`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`]. + See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information. + + Args: + feature_extractor (`Pop2PianoFeatureExtractor`): + An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`Pop2PianoTokenizer`): + An instance of ['Pop2PianoTokenizer`]. The tokenizer is a required input. + """ + + attributes = ["feature_extractor", "tokenizer"] + feature_extractor_class = "Pop2PianoFeatureExtractor" + tokenizer_class = "Pop2PianoTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__( + self, + audio: Union[np.ndarray, List[float], List[np.ndarray]] = None, + sampling_rate: Union[int, List[int]] = None, + steps_per_beat: int = 2, + resample: Optional[bool] = True, + notes: Union[List, TensorType] = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + verbose: bool = True, + **kwargs, + ) -> Union[BatchFeature, BatchEncoding]: + """ + This method uses [`Pop2PianoFeatureExtractor.__call__`] method to prepare log-mel-spectrograms for the model, + and [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes. + + Please refer to the docstring of the above two methods for more information. + """ + + # Since Feature Extractor needs both audio and sampling_rate and tokenizer needs both token_ids and + # feature_extractor_output, we must check for both. + if (audio is None and sampling_rate is None) and (notes is None): + raise ValueError( + "You have to specify at least audios and sampling_rate in order to use feature extractor or " + "notes to use the tokenizer part." + ) + + if audio is not None and sampling_rate is not None: + inputs = self.feature_extractor( + audio=audio, + sampling_rate=sampling_rate, + steps_per_beat=steps_per_beat, + resample=resample, + **kwargs, + ) + if notes is not None: + encoded_token_ids = self.tokenizer( + notes=notes, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + if notes is None: + return inputs + + elif audio is None or sampling_rate is None: + return encoded_token_ids + + else: + inputs["token_ids"] = encoded_token_ids["token_ids"] + return inputs + + def batch_decode( + self, + token_ids, + feature_extractor_output: BatchFeature, + return_midi: bool = True, + ) -> BatchEncoding: + """ + This method uses [`Pop2PianoTokenizer.batch_decode`] method to convert model generated token_ids to midi_notes. + + Please refer to the docstring of the above two methods for more information. 
+ """ + + return self.tokenizer.batch_decode( + token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) + + def save_pretrained(self, save_directory, **kwargs): + if os.path.isfile(save_directory): + raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") + os.makedirs(save_directory, exist_ok=True) + return super().save_pretrained(save_directory, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) + return cls(*args) + + +__all__ = ["Pop2PianoProcessor"] diff --git a/docs/transformers/src/transformers/models/pop2piano/tokenization_pop2piano.py b/docs/transformers/src/transformers/models/pop2piano/tokenization_pop2piano.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7a8a558669186541281d2168fe3314fbf356f2 --- /dev/null +++ b/docs/transformers/src/transformers/models/pop2piano/tokenization_pop2piano.py @@ -0,0 +1,723 @@ +# coding=utf-8 +# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Pop2Piano.""" + +import json +import os +from typing import List, Optional, Tuple, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...tokenization_utils import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy +from ...utils import TensorType, is_pretty_midi_available, logging, requires_backends, to_numpy +from ...utils.import_utils import requires + + +if is_pretty_midi_available(): + import pretty_midi + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = { + "vocab": "vocab.json", +} + + +def token_time_to_note(number, cutoff_time_idx, current_idx): + current_idx += number + if cutoff_time_idx is not None: + current_idx = min(current_idx, cutoff_time_idx) + + return current_idx + + +def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes): + if note_onsets_ready[number] is not None: + # offset with onset + onset_idx = note_onsets_ready[number] + if onset_idx < current_idx: + # Time shift after previous note_on + offset_idx = current_idx + notes.append([onset_idx, offset_idx, number, default_velocity]) + onsets_ready = None if current_velocity == 0 else current_idx + note_onsets_ready[number] = onsets_ready + else: + note_onsets_ready[number] = current_idx + return notes + + +@requires(backends=("pretty_midi", "torch")) +class Pop2PianoTokenizer(PreTrainedTokenizer): + """ + Constructs a Pop2Piano tokenizer. This tokenizer does not require training. 
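+    The vocabulary maps strings of the form `"{value}_{token_type}"` (for example `"75_TOKEN_NOTE"`) to integer ids,
+    where the token type is one of `"TOKEN_TIME"`, `"TOKEN_VELOCITY"`, `"TOKEN_NOTE"` or `"TOKEN_SPECIAL"`.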
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab (`str`):
+            Path to the vocab file which contains the vocabulary.
+        default_velocity (`int`, *optional*, defaults to 77):
+            Determines the default velocity to be used while creating midi Notes.
+        num_bars (`int`, *optional*, defaults to 2):
+            Determines the cutoff_time_idx for each token.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"1"`):
+            The end of sequence token.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"0"`):
+            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored
+            by attention mechanisms or loss computation.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"2"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+    """
+
+    model_input_names = ["token_ids", "attention_mask"]
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab,
+        default_velocity=77,
+        num_bars=2,
+        unk_token="-1",
+        eos_token="1",
+        pad_token="0",
+        bos_token="2",
+        **kwargs,
+    ):
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+
+        self.default_velocity = default_velocity
+        self.num_bars = num_bars
+
+        # Load the vocab
+        with open(vocab, "rb") as file:
+            self.encoder = json.load(file)
+
+        # create mappings for encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        super().__init__(
+            unk_token=unk_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            bos_token=bos_token,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        """Returns the vocabulary size of the tokenizer."""
+        return len(self.encoder)
+
+    def get_vocab(self):
+        """Returns the vocabulary of the tokenizer."""
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    def _convert_id_to_token(self, token_id: int) -> list:
+        """
+        Decodes the token ids generated by the transformer into notes.
+
+        Args:
+            token_id (`int`):
+                This denotes the ids generated by the transformers to be converted to Midi tokens.
+
+        Returns:
+            `List`: A list consisting of token_type (`str`) and value (`int`).
+        """
+
+        token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME")
+        token_type_value = token_type_value.split("_")
+        token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0])
+
+        return [token_type, value]
+
+    def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int:
+        """
+        Encodes the Midi tokens to transformer generated token ids.
+
+        Args:
+            token (`int`):
+                This denotes the token value.
+            token_type (`str`):
+                This denotes the type of the token. There are four types of midi tokens: "TOKEN_TIME",
+                "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".
+
+        Returns:
+            `int`: returns the id of the token.
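+
+        For example, `_convert_token_to_id(75, "TOKEN_NOTE")` looks up the key `"75_TOKEN_NOTE"` in the vocabulary
+        and falls back to the unknown token id when the key is absent.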
+ """ + return self.encoder.get(f"{token}_{token_type}", int(self.unk_token)) + + def relative_batch_tokens_ids_to_notes( + self, + tokens: np.ndarray, + beat_offset_idx: int, + bars_per_batch: int, + cutoff_time_idx: int, + ): + """ + Converts relative tokens to notes which are then used to generate pretty midi object. + + Args: + tokens (`numpy.ndarray`): + Tokens to be converted to notes. + beat_offset_idx (`int`): + Denotes beat offset index for each note in generated Midi. + bars_per_batch (`int`): + A parameter to control the Midi output generation. + cutoff_time_idx (`int`): + Denotes the cutoff time index for each note in generated Midi. + """ + + notes = None + + for index in range(len(tokens)): + _tokens = tokens[index] + _start_idx = beat_offset_idx + index * bars_per_batch * 4 + _cutoff_time_idx = cutoff_time_idx + _start_idx + _notes = self.relative_tokens_ids_to_notes( + _tokens, + start_idx=_start_idx, + cutoff_time_idx=_cutoff_time_idx, + ) + + if len(_notes) == 0: + pass + elif notes is None: + notes = _notes + else: + notes = np.concatenate((notes, _notes), axis=0) + + if notes is None: + return [] + return notes + + def relative_batch_tokens_ids_to_midi( + self, + tokens: np.ndarray, + beatstep: np.ndarray, + beat_offset_idx: int = 0, + bars_per_batch: int = 2, + cutoff_time_idx: int = 12, + ): + """ + Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens + to notes then uses `notes_to_midi` method to convert them to Midi. + + Args: + tokens (`numpy.ndarray`): + Denotes tokens which alongside beatstep will be converted to Midi. + beatstep (`np.ndarray`): + We get beatstep from feature extractor which is also used to get Midi. + beat_offset_idx (`int`, *optional*, defaults to 0): + Denotes beat offset index for each note in generated Midi. + bars_per_batch (`int`, *optional*, defaults to 2): + A parameter to control the Midi output generation. + cutoff_time_idx (`int`, *optional*, defaults to 12): + Denotes the cutoff time index for each note in generated Midi. + """ + beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx + notes = self.relative_batch_tokens_ids_to_notes( + tokens=tokens, + beat_offset_idx=beat_offset_idx, + bars_per_batch=bars_per_batch, + cutoff_time_idx=cutoff_time_idx, + ) + midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx]) + return midi + + # Taken from the original code + # Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257 + def relative_tokens_ids_to_notes( + self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: Optional[float] = None + ): + """ + Converts relative tokens to notes which will then be used to create Pretty Midi objects. + + Args: + tokens (`numpy.ndarray`): + Relative Tokens which will be converted to notes. + start_idx (`float`): + A parameter which denotes the starting index. + cutoff_time_idx (`float`, *optional*): + A parameter used while converting tokens to notes. 
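+
+        Returns:
+            `numpy.ndarray` of shape `[num_notes, 4]`: each row is `[onset_idx, offset_idx, pitch, velocity]`,
+            sorted by onset index (with offset breaking ties); an empty list is returned when no notes are decoded.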
+        """
+        words = [self._convert_id_to_token(token) for token in tokens]
+
+        current_idx = start_idx
+        current_velocity = 0
+        note_onsets_ready = [None for i in range(sum([k.endswith("NOTE") for k in self.encoder.keys()]) + 1)]
+        notes = []
+        for token_type, number in words:
+            if token_type == "TOKEN_SPECIAL":
+                if number == 1:
+                    break
+            elif token_type == "TOKEN_TIME":
+                current_idx = token_time_to_note(
+                    number=number, cutoff_time_idx=cutoff_time_idx, current_idx=current_idx
+                )
+            elif token_type == "TOKEN_VELOCITY":
+                current_velocity = number
+
+            elif token_type == "TOKEN_NOTE":
+                notes = token_note_to_note(
+                    number=number,
+                    current_velocity=current_velocity,
+                    default_velocity=self.default_velocity,
+                    note_onsets_ready=note_onsets_ready,
+                    current_idx=current_idx,
+                    notes=notes,
+                )
+            else:
+                raise ValueError("Token type not understood!")
+
+        for pitch, note_onset in enumerate(note_onsets_ready):
+            # force offset if no offset for each pitch
+            if note_onset is not None:
+                if cutoff_time_idx is None:
+                    cutoff = note_onset + 1
+                else:
+                    cutoff = max(cutoff_time_idx, note_onset + 1)
+
+                offset_idx = max(current_idx, cutoff)
+                notes.append([note_onset, offset_idx, pitch, self.default_velocity])
+
+        if len(notes) == 0:
+            return []
+        else:
+            notes = np.array(notes)
+            note_order = notes[:, 0] * 128 + notes[:, 1]
+            notes = notes[note_order.argsort()]
+            return notes
+
+    def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: float = 0.0):
+        """
+        Converts notes to Midi.
+
+        Args:
+            notes (`numpy.ndarray`):
+                This is used to create Pretty Midi objects.
+            beatstep (`numpy.ndarray`):
+                This is the extrapolated beatstep that we get from the feature extractor.
+            offset_sec (`float`, *optional*, defaults to 0.0):
+                This represents the offset in seconds which is used while creating each Pretty Midi Note.
+        """
+
+        requires_backends(self, ["pretty_midi"])
+
+        new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0)
+        new_inst = pretty_midi.Instrument(program=0)
+        new_notes = []
+
+        for onset_idx, offset_idx, pitch, velocity in notes:
+            new_note = pretty_midi.Note(
+                velocity=velocity,
+                pitch=pitch,
+                start=beatstep[onset_idx] - offset_sec,
+                end=beatstep[offset_idx] - offset_sec,
+            )
+            new_notes.append(new_note)
+        new_inst.notes = new_notes
+        new_pm.instruments.append(new_inst)
+        new_pm.remove_invalid_notes()
+        return new_pm
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        """
+        Saves the tokenizer's vocabulary dictionary to the provided save_directory.
+
+        Args:
+            save_directory (`str`):
+                A path to the directory where the vocabulary will be saved. It will be created if it doesn't exist.
+            filename_prefix (`Optional[str]`, *optional*):
+                A prefix to add to the names of the files saved by the tokenizer.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+
+        # Save the encoder.
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
+        )
+        with open(out_vocab_file, "w") as file:
+            file.write(json.dumps(self.encoder))
+
+        return (out_vocab_file,)
+
+    def encode_plus(
+        self,
+        notes: Union[np.ndarray, List[pretty_midi.Note]],
+        truncation_strategy: Optional[TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        **kwargs,
+    ) -> BatchEncoding:
+        r"""
+        This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
+        generated token ids.
It only works on a single batch, to process multiple batches please use + `batch_encode_plus` or `__call__` method. + + Args: + notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects): + This represents the midi notes. If `notes` is a `numpy.ndarray`: + - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`. + If `notes` is a `list` containing `pretty_midi.Note` objects: + - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`. + truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*): + Indicates the truncation strategy that is going to be used during truncation. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + + Returns: + `BatchEncoding` containing the tokens ids. + """ + + requires_backends(self, ["pretty_midi"]) + + # check if notes is a pretty_midi object or not, if yes then extract the attributes and put them into a numpy + # array. + if isinstance(notes[0], pretty_midi.Note): + notes = np.array( + [[each_note.start, each_note.end, each_note.pitch, each_note.velocity] for each_note in notes] + ).reshape(-1, 4) + + # to round up all the values to the closest int values. + notes = np.round(notes).astype(np.int32) + max_time_idx = notes[:, :2].max() + + times = [[] for i in range((max_time_idx + 1))] + for onset, offset, pitch, velocity in notes: + times[onset].append([pitch, velocity]) + times[offset].append([pitch, 0]) + + tokens = [] + current_velocity = 0 + for i, time in enumerate(times): + if len(time) == 0: + continue + tokens.append(self._convert_token_to_id(i, "TOKEN_TIME")) + for pitch, velocity in time: + velocity = int(velocity > 0) + if current_velocity != velocity: + current_velocity = velocity + tokens.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY")) + tokens.append(self._convert_token_to_id(pitch, "TOKEN_NOTE")) + + total_len = len(tokens) + + # truncation + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + tokens, _, _ = self.truncate_sequences( + ids=tokens, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + **kwargs, + ) + + return BatchEncoding({"token_ids": tokens}) + + def batch_encode_plus( + self, + notes: Union[np.ndarray, List[pretty_midi.Note]], + truncation_strategy: Optional[TruncationStrategy] = None, + max_length: Optional[int] = None, + **kwargs, + ) -> BatchEncoding: + r""" + This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer + generated token ids. It works on multiple batches by calling `encode_plus` multiple times in a loop. + + Args: + notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects): + This represents the midi notes. If `notes` is a `numpy.ndarray`: + - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`. + If `notes` is a `list` containing `pretty_midi.Note` objects: + - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`. + truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*): + Indicates the truncation strategy that is going to be used during truncation. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). 
+ + Returns: + `BatchEncoding` containing the tokens ids. + """ + + encoded_batch_token_ids = [] + for i in range(len(notes)): + encoded_batch_token_ids.append( + self.encode_plus( + notes[i], + truncation_strategy=truncation_strategy, + max_length=max_length, + **kwargs, + )["token_ids"] + ) + + return BatchEncoding({"token_ids": encoded_batch_token_ids}) + + def __call__( + self, + notes: Union[ + np.ndarray, + List[pretty_midi.Note], + List[List[pretty_midi.Note]], + ], + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + r""" + This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated + token ids. + + Args: + notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects): + This represents the midi notes. + + If `notes` is a `numpy.ndarray`: + - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`. + If `notes` is a `list` containing `pretty_midi.Note` objects: + - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. 
If left unset or set to
+                `None`, this will use the predefined model maximum length if a maximum length is required by one of
+                the truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
+                truncation/padding to a maximum length will be deactivated.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+                [What are attention masks?](../glossary#attention-mask)
+            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
+
+        Returns:
+            `BatchEncoding` containing the token_ids.
+        """
+
+        # check if it is batched or not
+        # it is batched if it's a list containing a list of `pretty_midi.Notes` where the outer list contains all the
+        # batches and the inner list contains all Notes for a single batch. Otherwise if np.ndarray is passed it will
+        # be considered batched if it has a shape of `[batch_size, sequence_length, 4]` or ndim=3.
+        is_batched = notes.ndim == 3 if isinstance(notes, np.ndarray) else isinstance(notes[0], list)
+
+        # get the truncation and padding strategy
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        if is_batched:
+            # If the user has not explicitly mentioned `return_attention_mask` as False, we change it to True
+            return_attention_mask = True if return_attention_mask is None else return_attention_mask
+            token_ids = self.batch_encode_plus(
+                notes=notes,
+                truncation_strategy=truncation_strategy,
+                max_length=max_length,
+                **kwargs,
+            )
+        else:
+            token_ids = self.encode_plus(
+                notes=notes,
+                truncation_strategy=truncation_strategy,
+                max_length=max_length,
+                **kwargs,
+            )
+
+        # since we already have truncated sequences we are just left to do padding
+        token_ids = self.pad(
+            token_ids,
+            padding=padding_strategy,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_attention_mask=return_attention_mask,
+            return_tensors=return_tensors,
+            verbose=verbose,
+        )
+
+        return token_ids
+
+    def batch_decode(
+        self,
+        token_ids,
+        feature_extractor_output: BatchFeature,
+        return_midi: bool = True,
+    ):
+        r"""
+        This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
+        transformer to midi notes and returns them.
+
+        Args:
+            token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
+                Output token_ids of the `Pop2PianoForConditionalGeneration` model.
+            feature_extractor_output (`BatchFeature`):
+                Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatsteps"` and
+                `"extrapolated_beatstep"`.
Also `"attention_mask_beatsteps"` and
+                `"attention_mask_extrapolated_beatstep"`
+                should be present if they were returned by the feature extractor.
+            return_midi (`bool`, *optional*, defaults to `True`):
+                Whether to return midi object or not.
+        Returns:
+            If `return_midi` is True:
+                - `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
+            If `return_midi` is False:
+                - `BatchEncoding` containing `notes`.
+        """
+
+        # check if they have attention masks (attention_mask, attention_mask_beatsteps and
+        # attention_mask_extrapolated_beatstep) or not
+        attention_masks_present = bool(
+            hasattr(feature_extractor_output, "attention_mask")
+            and hasattr(feature_extractor_output, "attention_mask_beatsteps")
+            and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
+        )
+
+        # if we are processing batched inputs then we need attention_masks
+        if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
+            raise ValueError(
+                "attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
+                "for batched inputs! But one of them was not present."
+            )
+
+        # check for length mismatch between inputs_embeds, beatsteps and extrapolated_beatstep
+        if attention_masks_present:
+            # since we know about the number of examples in token_ids from attention_mask
+            if (
+                sum(feature_extractor_output["attention_mask"][:, 0] == 0)
+                != feature_extractor_output["beatsteps"].shape[0]
+                or feature_extractor_output["beatsteps"].shape[0]
+                != feature_extractor_output["extrapolated_beatstep"].shape[0]
+            ):
+                raise ValueError(
+                    "Length mismatch between token_ids, beatsteps and extrapolated_beatstep! Found "
+                    f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} "
+                    f"and extrapolated_beatsteps shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}"
+                )
+            if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]:
+                raise ValueError(
+                    f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}"
+                )
+        else:
+            # if there is no attention mask present then it's surely a single example
+            if (
+                feature_extractor_output["beatsteps"].shape[0] != 1
+                or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1
+            ):
+                raise ValueError(
+                    "Length mismatch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, "
+                    f"but found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatsteps length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}."
+ ) + + if attention_masks_present: + # check for zeros(since token_ids are seperated by zero arrays) + batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0] + else: + batch_idx = [token_ids.shape[0]] + + notes_list = [] + pretty_midi_objects_list = [] + start_idx = 0 + for index, end_idx in enumerate(batch_idx): + each_tokens_ids = token_ids[start_idx:end_idx] + # check where the whole example ended by searching for eos_token_id and getting the upper bound + each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1] + beatsteps = feature_extractor_output["beatsteps"][index] + extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index] + + # if attention mask is present then mask out real array/tensor + if attention_masks_present: + attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index] + attention_mask_extrapolated_beatstep = feature_extractor_output[ + "attention_mask_extrapolated_beatstep" + ][index] + beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1] + extrapolated_beatstep = extrapolated_beatstep[ + : np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1 + ] + + each_tokens_ids = to_numpy(each_tokens_ids) + beatsteps = to_numpy(beatsteps) + extrapolated_beatstep = to_numpy(extrapolated_beatstep) + + pretty_midi_object = self.relative_batch_tokens_ids_to_midi( + tokens=each_tokens_ids, + beatstep=extrapolated_beatstep, + bars_per_batch=self.num_bars, + cutoff_time_idx=(self.num_bars + 1) * 4, + ) + + for note in pretty_midi_object.instruments[0].notes: + note.start += beatsteps[0] + note.end += beatsteps[0] + notes_list.append(note) + + pretty_midi_objects_list.append(pretty_midi_object) + start_idx += end_idx + 1 # 1 represents the zero array + + if return_midi: + return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_objects_list}) + + return BatchEncoding({"notes": notes_list}) + + +__all__ = ["Pop2PianoTokenizer"] diff --git a/docs/transformers/src/transformers/models/prompt_depth_anything/__init__.py b/docs/transformers/src/transformers/models/prompt_depth_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb05f8e378874693b225823050dee69c206f946 --- /dev/null +++ b/docs/transformers/src/transformers/models/prompt_depth_anything/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
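+# NOTE: this package relies on the lazy-import pattern below: the `TYPE_CHECKING` branch resolves the names for
+# static analysis, while at runtime `_LazyModule` defers the actual imports until an attribute is first accessed.
+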
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_prompt_depth_anything import PromptDepthAnythingConfig + from .image_processing_prompt_depth_anything import PromptDepthAnythingImageProcessor + from .modeling_prompt_depth_anything import ( + PromptDepthAnythingForDepthEstimation, + PromptDepthAnythingPreTrainedModel, + ) +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py b/docs/transformers/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..cf213133c1424d185b7b38ca5f84d1f7e1526522 --- /dev/null +++ b/docs/transformers/src/transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py @@ -0,0 +1,171 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_prompt_depth_anything.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class PromptDepthAnythingConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PromptDepthAnythingModel`]. It is used to instantiate a PromptDepthAnything + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the PromptDepthAnything + [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*): + The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to + leverage the [`AutoBackbone`] API. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. 
If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If
+            `use_pretrained_backbone` is `False`, this loads the backbone's config and uses that to initialize the
+            backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+            API.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size of the patches to extract from the backbone features.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        reassemble_hidden_size (`int`, *optional*, defaults to 384):
+            The number of input channels of the reassemble layers.
+        reassemble_factors (`List[float]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
+            The up/downsampling factors of the reassemble layers.
+        neck_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192, 384]`):
+            The hidden sizes to project to for the feature maps of the backbone.
+        fusion_hidden_size (`int`, *optional*, defaults to 64):
+            The number of channels before fusion.
+        head_in_index (`int`, *optional*, defaults to -1):
+            The index of the features to use in the depth estimation head.
+        head_hidden_size (`int`, *optional*, defaults to 32):
+            The number of output channels in the second convolution of the depth estimation head.
+        depth_estimation_type (`str`, *optional*, defaults to `"relative"`):
+            The type of depth estimation to use. Can be one of `["relative", "metric"]`.
+        max_depth (`float`, *optional*):
+            The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models
+            and 80 for outdoor models. For "relative" depth estimation, this value is ignored.
+
+    Example:
+
+    ```python
+    >>> from transformers import PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation
+
+    >>> # Initializing a PromptDepthAnything small style configuration
+    >>> configuration = PromptDepthAnythingConfig()
+
+    >>> # Initializing a model from the PromptDepthAnything small style configuration
+    >>> model = PromptDepthAnythingForDepthEstimation(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "prompt_depth_anything"
+
+    def __init__(
+        self,
+        backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
+        use_timm_backbone=False,
+        backbone_kwargs=None,
+        patch_size=14,
+        initializer_range=0.02,
+        reassemble_hidden_size=384,
+        reassemble_factors=[4, 2, 1, 0.5],
+        neck_hidden_sizes=[48, 96, 192, 384],
+        fusion_hidden_size=64,
+        head_in_index=-1,
+        head_hidden_size=32,
+        depth_estimation_type="relative",
+        max_depth=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if backbone_config is None and backbone is None:
+            logger.info("`backbone_config` is `None`. 
Initializing the config with the default `Dinov2` backbone.") + backbone_config = CONFIG_MAPPING["dinov2"]( + image_size=518, + hidden_size=384, + num_attention_heads=6, + out_indices=[9, 10, 11, 12], + apply_layernorm=True, + reshape_hidden_states=False, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.reassemble_hidden_size = reassemble_hidden_size + self.patch_size = patch_size + self.initializer_range = initializer_range + self.reassemble_factors = reassemble_factors + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.head_in_index = head_in_index + self.head_hidden_size = head_hidden_size + if depth_estimation_type not in ["relative", "metric"]: + raise ValueError("depth_estimation_type must be one of ['relative', 'metric']") + self.depth_estimation_type = depth_estimation_type + self.max_depth = max_depth if max_depth else 1 + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output + + +__all__ = ["PromptDepthAnythingConfig"] diff --git a/docs/transformers/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py b/docs/transformers/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..237be38fff3ed051cbe3795012cb1245d520df37 --- /dev/null +++ b/docs/transformers/src/transformers/models/prompt_depth_anything/convert_prompt_depth_anything_to_hf.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Prompt Depth Anything checkpoints from the original repository. 
URL: +https://github.com/DepthAnything/PromptDA""" + +import argparse +import re +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + Dinov2Config, + PromptDepthAnythingConfig, + PromptDepthAnythingForDepthEstimation, + PromptDepthAnythingImageProcessor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(model_name): + if "small" in model_name or "vits" in model_name: + out_indices = [3, 6, 9, 12] + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False + ) + fusion_hidden_size = 64 + neck_hidden_sizes = [48, 96, 192, 384] + elif "base" in model_name or "vitb" in model_name: + out_indices = [3, 6, 9, 12] + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False + ) + fusion_hidden_size = 128 + neck_hidden_sizes = [96, 192, 384, 768] + elif "large" in model_name or "vitl" in model_name: + out_indices = [5, 12, 18, 24] + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False + ) + fusion_hidden_size = 256 + neck_hidden_sizes = [256, 512, 1024, 1024] + else: + raise NotImplementedError(f"Model not supported: {model_name}") + + depth_estimation_type = "metric" + max_depth = None + + config = PromptDepthAnythingConfig( + reassemble_hidden_size=backbone_config.hidden_size, + patch_size=backbone_config.patch_size, + backbone_config=backbone_config, + fusion_hidden_size=fusion_hidden_size, + neck_hidden_sizes=neck_hidden_sizes, + depth_estimation_type=depth_estimation_type, + max_depth=max_depth, + ) + + return config + + +def transform_qkv_weights(key, value, config): + if not key.startswith("qkv_transform"): + return value + + layer_idx = int(key.split("_")[-1]) + hidden_size = config.backbone_config.hidden_size + + suffix = "bias" if "bias" in key else "weight" + return { + f"backbone.encoder.layer.{layer_idx}.attention.attention.query.{suffix}": value[:hidden_size], + f"backbone.encoder.layer.{layer_idx}.attention.attention.key.{suffix}": value[hidden_size : hidden_size * 2], + f"backbone.encoder.layer.{layer_idx}.attention.attention.value.{suffix}": value[-hidden_size:], + } + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Stem + r"pretrained.cls_token": r"backbone.embeddings.cls_token", + r"pretrained.mask_token": r"backbone.embeddings.mask_token", + r"pretrained.pos_embed": r"backbone.embeddings.position_embeddings", + r"pretrained.patch_embed.proj.(weight|bias)": r"backbone.embeddings.patch_embeddings.projection.\1", + # Backbone + r"pretrained.norm.(weight|bias)": r"backbone.layernorm.\1", + # Transformer layers + r"pretrained.blocks.(\d+).ls1.gamma": r"backbone.encoder.layer.\1.layer_scale1.lambda1", + r"pretrained.blocks.(\d+).ls2.gamma": r"backbone.encoder.layer.\1.layer_scale2.lambda1", + r"pretrained.blocks.(\d+).norm1.(weight|bias)": r"backbone.encoder.layer.\1.norm1.\2", + r"pretrained.blocks.(\d+).norm2.(weight|bias)": r"backbone.encoder.layer.\1.norm2.\2", + r"pretrained.blocks.(\d+).mlp.fc1.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc1.\2", + r"pretrained.blocks.(\d+).mlp.fc2.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc2.\2", + r"pretrained.blocks.(\d+).attn.proj.(weight|bias)": 
r"backbone.encoder.layer.\1.attention.output.dense.\2", + r"pretrained.blocks.(\d+).attn.qkv.(weight|bias)": r"qkv_transform_\2_\1", + # Neck + r"depth_head.projects.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.projection.\2", + r"depth_head.scratch.layer(\d+)_rn.weight": lambda m: f"neck.convs.{int(m.group(1)) - 1}.weight", + r"depth_head.resize_layers.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.resize.\2", + # Refinenet (with reversed indices) + r"depth_head.scratch.refinenet(\d+).out_conv.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.projection.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.residual_layer1.convolution1.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.residual_layer1.convolution2.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.residual_layer2.convolution1.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.residual_layer2.convolution2.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.prompt_depth_layer.convolution1.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.prompt_depth_layer.convolution2.{m.group(2)}", + r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4 - int(m.group(1))}.prompt_depth_layer.convolution3.{m.group(2)}", + # Head + r"depth_head.scratch.output_conv1.(weight|bias)": r"head.conv1.\1", + r"depth_head.scratch.output_conv2.0.(weight|bias)": r"head.conv2.\1", + r"depth_head.scratch.output_conv2.2.(weight|bias)": r"head.conv3.\1", +} + + +def convert_old_keys_to_new_keys(state_dict_keys: dict = None): + """ + Convert old state dict keys to new keys using regex patterns. + """ + output_dict = {} + if state_dict_keys is not None: + for old_key in state_dict_keys: + new_key = old_key + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + match = re.match(pattern, old_key) + if match: + if callable(replacement): + new_key = replacement(match) + else: + new_key = re.sub(pattern, replacement, old_key) + break + output_dict[old_key] = new_key + return output_dict + + +@torch.no_grad() +def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits): + """ + Copy/paste/tweak model's weights to our DPT structure. 
+ """ + + # define DPT configuration + config = get_dpt_config(model_name) + + model_name_to_repo = { + "prompt-depth-anything-vits": "depth-anything/prompt-depth-anything-vits", + "prompt-depth-anything-vits-transparent": "depth-anything/prompt-depth-anything-vits-transparent", + "prompt-depth-anything-vitl": "depth-anything/prompt-depth-anything-vitl", + } + + # load original state_dict + repo_id = model_name_to_repo[model_name] + filename = name_to_checkpoint[model_name] + filepath = hf_hub_download( + repo_id=repo_id, + filename=f"{filename}", + ) + + state_dict = torch.load(filepath, map_location="cpu", weights_only=True)["state_dict"] + state_dict = {key[9:]: state_dict[key] for key in state_dict} + + # Convert state dict using mappings + key_mapping = convert_old_keys_to_new_keys(state_dict.keys()) + new_state_dict = {} + for key, value in state_dict.items(): + new_key = key_mapping[key] + transformed_value = transform_qkv_weights(new_key, value, config) + if isinstance(transformed_value, dict): + new_state_dict.update(transformed_value) + else: + new_state_dict[new_key] = transformed_value + + # load HuggingFace model + model = PromptDepthAnythingForDepthEstimation(config) + model.load_state_dict(new_state_dict, strict=False) + model.eval() + + processor = PromptDepthAnythingImageProcessor( + do_resize=True, + size=756, + ensure_multiple_of=14, + keep_aspect_ratio=True, + do_rescale=True, + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + image = Image.open(requests.get(url, stream=True).raw) + + prompt_depth_url = ( + "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + ) + prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + + inputs = processor(image, return_tensors="pt", prompt_depth=prompt_depth) + + # Verify forward pass + with torch.no_grad(): + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + + print("Shape of predicted depth:", predicted_depth.shape) + print("First values:", predicted_depth[0, :3, :3]) + + # assert logits + if verify_logits: + expected_shape = torch.Size([1, 756, 1008]) + if model_name == "prompt-depth-anything-vits": + expected_slice = torch.tensor( + [[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]] + ) + elif model_name == "prompt-depth-anything-vits-transparent": + expected_slice = torch.tensor( + [[3.0058, 3.0397, 3.0460], [3.0314, 3.0393, 3.0504], [3.0326, 3.0465, 3.0545]] + ) + elif model_name == "prompt-depth-anything-vitl": + expected_slice = torch.tensor( + [[3.1336, 3.1358, 3.1363], [3.1368, 3.1267, 3.1414], [3.1397, 3.1385, 3.1448]] + ) + else: + raise ValueError("Not supported") + assert predicted_depth.shape == torch.Size(expected_shape) + assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=5e-3) # 5mm tolerance + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model and processor to hub...") + model.push_to_hub(repo_id=f"{model_name.title()}-hf") + processor.push_to_hub(repo_id=f"{model_name.title()}-hf") + + +name_to_checkpoint = { + "prompt-depth-anything-vits": "model.ckpt", 
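+    # every supported variant stores its weights under the same filename in its hub repo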
+ "prompt-depth-anything-vits-transparent": "model.ckpt", + "prompt-depth-anything-vitl": "model.ckpt", +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="prompt_depth_anything_vits", + type=str, + choices=name_to_checkpoint.keys(), + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model to the hub after conversion.", + ) + parser.add_argument( + "--verify_logits", + action="store_false", + required=False, + help="Whether to verify the logits after conversion.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) diff --git a/docs/transformers/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/docs/transformers/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..898475835be833cd842ffc4c4bd7a5d42b1e0edd --- /dev/null +++ b/docs/transformers/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -0,0 +1,504 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for PromptDepthAnything.""" + +import math +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import pad, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_torch_available, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + logging, + requires_backends, +) + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +def _constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): + x = round(val / multiple) * multiple + + if max_val is not None and x > max_val: + x = math.floor(val / multiple) * multiple + + if x < min_val: + x = math.ceil(val / multiple) * multiple + + return x + + +def _get_resize_output_image_size( + input_image: np.ndarray, + output_size: Union[int, Iterable[int]], + keep_aspect_ratio: bool, + multiple: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + output_size = (output_size, output_size) if isinstance(output_size, int) else output_size + + input_height, input_width = get_image_size(input_image, input_data_format) + output_height, output_width = output_size + + # determine new height and width + scale_height = output_height / input_height + scale_width = output_width / input_width + + if keep_aspect_ratio: + # scale as little as possible + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + + new_height = _constrain_to_multiple_of(scale_height * input_height, multiple=multiple) + new_width = _constrain_to_multiple_of(scale_width * input_width, multiple=multiple) + + return (new_height, new_width) + + +class PromptDepthAnythingImageProcessor(BaseImageProcessor): + r""" + Constructs a PromptDepthAnything image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`. + size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the image after resizing. Can be overidden by `size` in `preprocess`. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can + be overidden by `keep_aspect_ratio` in `preprocess`. + ensure_multiple_of (`int`, *optional*, defaults to 1): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden + by `ensure_multiple_of` in `preprocess`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in + `preprocess`. 
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_pad (`bool`, *optional*, defaults to `False`):
+            Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
+            combination with DPT.
+        size_divisor (`int`, *optional*):
+            If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
+            DINOv2 paper, which uses the model in combination with DPT.
+        prompt_scale_to_meter (`float`, *optional*, defaults to 0.001):
+            Scale factor to convert the prompt depth to meters.
+    """
+
+    model_input_names = ["pixel_values", "prompt_depth"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        keep_aspect_ratio: bool = False,
+        ensure_multiple_of: int = 1,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_pad: bool = False,
+        size_divisor: Optional[int] = None,
+        prompt_scale_to_meter: float = 0.001,  # default unit is mm
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 384, "width": 384}
+        size = get_size_dict(size)
+        self.do_resize = do_resize
+        self.size = size
+        self.keep_aspect_ratio = keep_aspect_ratio
+        self.ensure_multiple_of = ensure_multiple_of
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.do_pad = do_pad
+        self.size_divisor = size_divisor
+        self.prompt_scale_to_meter = prompt_scale_to_meter
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        keep_aspect_ratio: bool = False,
+        ensure_multiple_of: int = 1,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
+        is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
+        set, the image is resized to a size that is a multiple of this value.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Target size of the output image.
+            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
+            ensure_multiple_of (`int`, *optional*, defaults to 1):
+                The image is resized to a size that is a multiple of this value.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+
+        output_size = _get_resize_output_image_size(
+            image,
+            output_size=(size["height"], size["width"]),
+            keep_aspect_ratio=keep_aspect_ratio,
+            multiple=ensure_multiple_of,
+            input_data_format=input_data_format,
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    def pad_image(
+        self,
+        image: np.ndarray,
+        size_divisor: int,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+        """
+        Center pad an image so that its height and width are multiples of `size_divisor`.
+
+        Args:
+            image (`np.ndarray`):
+                Image to pad.
+            size_divisor (`int`):
+                The width and height of the image will be padded to a multiple of this number.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
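To make the size arithmetic above concrete, here is a small self-contained sketch of the two computations involved: the aspect-preserving resize snapped to a multiple (using the module-level helper defined above) and the even per-side split performed by the `_get_pad` closure inside `pad_image` (re-implemented here because the closure is local). The input sizes are illustrative:

```python
import math

import numpy as np

# Aspect-preserving resize snapped to multiples of 14 (the ViT patch size):
# a 512x768 image targeted at 756x756 keeps the width scale (closest to 1.0).
image = np.zeros((512, 768, 3), dtype=np.uint8)
new_height, new_width = _get_resize_output_image_size(
    image, output_size=(756, 756), keep_aspect_ratio=True, multiple=14
)
print(new_height, new_width)  # 504 756

# Centre padding: the required padding is split as evenly as possible per side.
def _get_pad(size, size_divisor):
    new_size = math.ceil(size / size_divisor) * size_divisor
    pad_size = new_size - size
    before = pad_size // 2
    return before, pad_size - before

print(_get_pad(518, 32))  # (13, 13) -> height 518 becomes 544
print(_get_pad(684, 32))  # (10, 10) -> width 684 becomes 704
```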
+ """ + + def _get_pad(size, size_divisor): + new_size = math.ceil(size / size_divisor) * size_divisor + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + height, width = get_image_size(image, input_data_format) + + pad_size_left, pad_size_right = _get_pad(height, size_divisor) + pad_size_top, pad_size_bottom = _get_pad(width, size_divisor) + + padded_image = pad( + image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format + ) + return padded_image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + prompt_depth: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[int] = None, + keep_aspect_ratio: Optional[bool] = None, + ensure_multiple_of: Optional[int] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisor: Optional[int] = None, + prompt_scale_to_meter: Optional[float] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + prompt_depth (`ImageInput`, *optional*): + Prompt depth to preprocess, which can be sparse depth obtained from multi-view geometry or + low-resolution depth from a depth sensor. Generally has shape (height, width), where height + and width can be smaller than those of the images. It's optional and can be None, which means no prompt depth + is used. If it is None, the output depth will be a monocular relative depth. + It is recommended to provide a prompt_scale_to_meter value, which is the scale factor to convert the prompt depth + to meters. This is useful when the prompt depth is not in meters. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest + possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is + resized to a size that is a multiple of this value. + keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`): + Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If + True, the image will be resized to keep the aspect ratio and the size will be the maximum possible. + ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`): + Ensure that the image size is a multiple of this value. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. 
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + prompt_scale_to_meter (`float`, *optional*, defaults to `self.prompt_scale_to_meter`): + Scale factor to convert the prompt depth to meters. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio + ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + size_divisor = size_divisor if size_divisor is not None else self.size_divisor + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisor, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. 
+ images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + preprocessed_images = [] + for image in images: + if do_resize: + image = self.resize( + image=image, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + if do_pad: + image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + preprocessed_images.append(image) + + images = preprocessed_images + + data = {"pixel_values": images} + if prompt_depth is not None: + # prompt_depth is a list of images with shape (height, width) + # we need to convert it to a list of images with shape (1, height, width) + prompt_depths = make_list_of_images(prompt_depth, expected_ndims=2) + + # Validate prompt_depths has same length as images + if len(prompt_depths) != len(images): + raise ValueError( + f"Number of prompt depth images ({len(prompt_depths)}) does not match number of input images ({len(images)})" + ) + + if prompt_scale_to_meter is None: + prompt_scale_to_meter = self.prompt_scale_to_meter + + processed_prompt_depths = [] + for depth in prompt_depths: + depth = to_numpy_array(depth) + depth = depth * prompt_scale_to_meter + if depth.min() == depth.max(): + # Prompt depth is invalid, min and max are the same. + # We can simply select one pixel and set it to a small value. + depth[0, 0] = depth[0, 0] + 1e-6 + depth = depth[..., None].astype(np.float32) + depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format) + + processed_prompt_depths.append(depth) + prompt_depths = processed_prompt_depths + data["prompt_depth"] = prompt_depths + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.dpt.image_processing_dpt.DPTImageProcessor.post_process_depth_estimation with DPT->PromptDepthAnything + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None, + ) -> List[Dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. 
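As a usage sketch for the `preprocess` method above (the file name is hypothetical, and the prompt depth is a synthetic millimetre-scale array chosen for illustration):

```python
import numpy as np
from PIL import Image

# Processor configured like the conversion script earlier in this diff.
processor = PromptDepthAnythingImageProcessor(
    size={"height": 756, "width": 756}, keep_aspect_ratio=True, ensure_multiple_of=14
)

image = Image.open("example.jpg")  # hypothetical input image
prompt_depth = np.zeros((192, 256), dtype=np.uint16)  # e.g. a low-resolution sensor depth in mm
prompt_depth[96, 128] = 1500  # a single valid measurement: 1.5 m after scaling

inputs = processor(image, prompt_depth=prompt_depth, return_tensors="pt")
print(inputs["pixel_values"].shape)  # (1, 3, H, W), with H and W multiples of 14
print(inputs["prompt_depth"].shape)  # (1, 1, 192, 256), scaled to meters by prompt_scale_to_meter
```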
+ """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = torch.nn.functional.interpolate( + depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False + ).squeeze() + + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["PromptDepthAnythingImageProcessor"] diff --git a/docs/transformers/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/docs/transformers/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1e2c7e47fb96295974673d6c2598a462522a5f --- /dev/null +++ b/docs/transformers/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -0,0 +1,541 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_prompt_depth_anything.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.utils.generic import torch_int + +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_outputs import DepthEstimatorOutput +from ...modeling_utils import PreTrainedModel +from ...utils.backbone_utils import load_backbone +from .configuration_prompt_depth_anything import PromptDepthAnythingConfig + + +_CONFIG_FOR_DOC = "PromptDepthAnythingConfig" + + +class PromptDepthAnythingLayer(nn.Module): + def __init__(self, config: PromptDepthAnythingConfig): + super().__init__() + self.convolution1 = nn.Conv2d( + 1, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation1 = nn.ReLU() + + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation2 = nn.ReLU() + + self.convolution3 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: + hidden_state = self.convolution1(prompt_depth) + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution2(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution3(hidden_state) + return hidden_state + + +class PromptDepthAnythingPreActResidualLayer(nn.Module): + """ + ResidualConvUnit, pre-activate residual unit. + + Args: + config (`[PromptDepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.activation1 = nn.ReLU() + self.convolution1 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + self.activation2 = nn.ReLU() + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution1(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + return hidden_state + residual + + +class PromptDepthAnythingFeatureFusionLayer(nn.Module): + """Feature fusion layer, merges feature maps from different stages. + + Args: + config (`[PromptDepthAnythingConfig]`): + Model configuration class defining the model architecture. 
+ """ + + def __init__(self, config: PromptDepthAnythingConfig): + super().__init__() + + self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) + + self.residual_layer1 = PromptDepthAnythingPreActResidualLayer(config) + self.residual_layer2 = PromptDepthAnythingPreActResidualLayer(config) + self.prompt_depth_layer = PromptDepthAnythingLayer(config) + + def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + + if prompt_depth is not None: + prompt_depth = nn.functional.interpolate( + prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False + ) + res = self.prompt_depth_layer(prompt_depth) + hidden_state = hidden_state + res + + modifier = {"scale_factor": 2} if size is None else {"size": size} + + hidden_state = nn.functional.interpolate( + hidden_state, + **modifier, + mode="bilinear", + align_corners=True, + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class PromptDepthAnythingFeatureFusionStage(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(len(config.neck_hidden_sizes)): + self.layers.append(PromptDepthAnythingFeatureFusionLayer(config)) + + def forward(self, hidden_states, size=None, prompt_depth=None): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + fused_hidden_state = None + + for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)): + size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None + + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth) + + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +class PromptDepthAnythingDepthEstimationHead(nn.Module): + """ + Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples + the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's + supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation + type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining. 
+ """ + + def __init__(self, config): + super().__init__() + + self.head_in_index = config.head_in_index + self.patch_size = config.patch_size + + features = config.fusion_hidden_size + self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1) + self.activation1 = nn.ReLU() + self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0) + if config.depth_estimation_type == "relative": + self.activation2 = nn.ReLU() + elif config.depth_estimation_type == "metric": + self.activation2 = nn.Sigmoid() + else: + raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}") + self.max_depth = config.max_depth + + def forward(self, hidden_states: List[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor: + hidden_states = hidden_states[-1] + + predicted_depth = self.conv1(hidden_states) + target_height = torch_int(patch_height * self.patch_size) + target_width = torch_int(patch_width * self.patch_size) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (target_height, target_width), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) + # (batch_size, 1, height, width) -> (batch_size, height, width), which + # keeps the same behavior as Depth Anything v1 & v2 + predicted_depth = predicted_depth.squeeze(dim=1) + + return predicted_depth + + +class PromptDepthAnythingPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PromptDepthAnythingConfig + base_model_prefix = "prompt_depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +class PromptDepthAnythingReassembleLayer(nn.Module): + def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): + super().__init__() + self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + stride = torch_int(1 / factor) + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1) + + def forward(self, hidden_state): + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + + return hidden_state + + +class PromptDepthAnythingReassembleStage(nn.Module): + """ + This class reassembles the hidden states of the backbone into image-like feature representations at various + resolutions. + + This happens in 3 stages: + 1. Take the patch embeddings and reshape them to image-like feature representations. + 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. + 3. Resizing the spatial dimensions (height, width). 
+ + Args: + config (`[PromptDepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.layers = nn.ModuleList() + for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors): + self.layers.append(PromptDepthAnythingReassembleLayer(config, channels=channels, factor=factor)) + + def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. + """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + # reshape to (batch_size, num_channels, height, width) + hidden_state = hidden_state[:, 1:] + batch_size, _, num_channels = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + hidden_state = self.layers[i](hidden_state) + out.append(hidden_state) + + return out + + +class PromptDepthAnythingNeck(nn.Module): + """ + PromptDepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as + input and produces another list of tensors as output. For PromptDepthAnything, it includes 2 stages: + + * PromptDepthAnythingReassembleStage + * PromptDepthAnythingFeatureFusionStage. + + Args: + config (dict): config dict. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.reassemble_stage = PromptDepthAnythingReassembleStage(config) + + self.convs = nn.ModuleList() + for channel in config.neck_hidden_sizes: + self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)) + + # fusion + self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config) + + def forward( + self, + hidden_states: List[torch.Tensor], + patch_height: Optional[int] = None, + patch_width: Optional[int] = None, + prompt_depth: Optional[torch.Tensor] = None, + ) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features, prompt_depth=prompt_depth) + + return output + + +PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Prompt depth is the sparse or low-resolution depth obtained from multi-view geometry or a + low-resolution depth sensor. It generally has shape (height, width), where height + and width can be smaller than those of the images. It is optional and can be None, which means no prompt depth + will be used. If it is None, the output will be a monocular relative depth. + The values are recommended to be in meters, but this is not necessary. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. + """, + PROMPT_DEPTH_ANYTHING_START_DOCSTRING, +) +class PromptDepthAnythingForDepthEstimation(PromptDepthAnythingPreTrainedModel): + _no_split_modules = ["DPTViTEmbeddings"] + + def __init__(self, config): + super().__init__(config) + + self.backbone = load_backbone(config) + self.neck = PromptDepthAnythingNeck(config) + self.head = PromptDepthAnythingDepthEstimationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + prompt_depth: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. 
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + + >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 1000. + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint16")) # mm + ```""" + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + if prompt_depth is not None: + # normalize prompt depth + batch_size = prompt_depth.shape[0] + depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1) + prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) + # normalize done + + hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) + + predicted_depth = self.head(hidden_states, patch_height, patch_width) + if prompt_depth is not None: + # denormalize predicted depth + depth_min = depth_min.squeeze(1).to(predicted_depth.device) + depth_max = depth_max.squeeze(1).to(predicted_depth.device) + predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min + # denormalize done + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = 
["PromptDepthAnythingForDepthEstimation", "PromptDepthAnythingPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/docs/transformers/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..aa834339ea9cf2fc916eb69a401bd4ceee80a056 --- /dev/null +++ b/docs/transformers/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -0,0 +1,386 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.models.depth_anything.configuration_depth_anything import DepthAnythingConfig +from transformers.models.depth_anything.modeling_depth_anything import ( + DepthAnythingDepthEstimationHead, + DepthAnythingFeatureFusionLayer, + DepthAnythingFeatureFusionStage, + DepthAnythingForDepthEstimation, + DepthAnythingNeck, + DepthAnythingReassembleStage, +) +from transformers.utils.generic import torch_int + +from ...file_utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import DepthEstimatorOutput +from ...modeling_utils import PreTrainedModel + + +_CONFIG_FOR_DOC = "PromptDepthAnythingConfig" + + +class PromptDepthAnythingConfig(DepthAnythingConfig): + model_type = "prompt_depth_anything" + + +class PromptDepthAnythingLayer(nn.Module): + def __init__(self, config: PromptDepthAnythingConfig): + super().__init__() + self.convolution1 = nn.Conv2d( + 1, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation1 = nn.ReLU() + + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.activation2 = nn.ReLU() + + self.convolution3 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor: + hidden_state = self.convolution1(prompt_depth) + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution2(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution3(hidden_state) + return hidden_state + + +class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer): + def __init__(self, config: PromptDepthAnythingConfig): + super().__init__(config) + self.prompt_depth_layer = PromptDepthAnythingLayer(config) + + def forward(self, hidden_state, residual=None, size=None, prompt_depth=None): + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + 
hidden_state = self.residual_layer2(hidden_state) + + if prompt_depth is not None: + prompt_depth = nn.functional.interpolate( + prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False + ) + res = self.prompt_depth_layer(prompt_depth) + hidden_state = hidden_state + res + + modifier = {"scale_factor": 2} if size is None else {"size": size} + + hidden_state = nn.functional.interpolate( + hidden_state, + **modifier, + mode="bilinear", + align_corners=True, + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class PromptDepthAnythingFeatureFusionStage(DepthAnythingFeatureFusionStage): + def forward(self, hidden_states, size=None, prompt_depth=None): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + fused_hidden_state = None + + for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)): + size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None + + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth) + + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead): + def forward(self, hidden_states: List[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor: + hidden_states = hidden_states[-1] + + predicted_depth = self.conv1(hidden_states) + target_height = torch_int(patch_height * self.patch_size) + target_width = torch_int(patch_width * self.patch_size) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (target_height, target_width), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) + # (batch_size, 1, height, width) -> (batch_size, height, width), which + # keeps the same behavior as Depth Anything v1 & v2 + predicted_depth = predicted_depth.squeeze(dim=1) + + return predicted_depth + + +PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for + more detail. + prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*): + Prompt depth is the sparse or low-resolution depth obtained from multi-view geometry or a + low-resolution depth sensor. It generally has shape (height, width), where height + and width can be smaller than those of the images. It is optional and can be None, which means no prompt depth + will be used. If it is None, the output will be a monocular relative depth. + The values are recommended to be in meters, but this is not necessary. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +class PromptDepthAnythingPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PromptDepthAnythingConfig + base_model_prefix = "prompt_depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +class PromptDepthAnythingReassembleLayer(nn.Module): + def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): + super().__init__() + self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + stride = torch_int(1 / factor) + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1) + + def forward(self, hidden_state): + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + + return hidden_state + + +class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage): + pass + + +class PromptDepthAnythingNeck(DepthAnythingNeck): + def forward( + self, + hidden_states: List[torch.Tensor], + patch_height: Optional[int] = None, + patch_width: Optional[int] = None, + prompt_depth: Optional[torch.Tensor] = None, + ) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features, prompt_depth=prompt_depth) + + return output + + +@add_start_docstrings( + """ + Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. 
+ """, + PROMPT_DEPTH_ANYTHING_START_DOCSTRING, +) +class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation): + @add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + prompt_depth: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + ```python + >>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + >>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf") + + >>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true" + >>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw) + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth) + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 1000. 
+ >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint16")) # mm + ```""" + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + if prompt_depth is not None: + # normalize prompt depth + batch_size = prompt_depth.shape[0] + depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values + depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1) + prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min) + # normalize done + + hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth) + + predicted_depth = self.head(hidden_states, patch_height, patch_width) + if prompt_depth is not None: + # denormalize predicted depth + depth_min = depth_min.squeeze(1).to(predicted_depth.device) + depth_max = depth_max.squeeze(1).to(predicted_depth.device) + predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min + # denormalize done + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = [ + "PromptDepthAnythingConfig", + "PromptDepthAnythingForDepthEstimation", + "PromptDepthAnythingPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/prophetnet/__init__.py b/docs/transformers/src/transformers/models/prophetnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7485dfbb530c3e11683c2b7dceaf35bf9edc9086 --- /dev/null +++ b/docs/transformers/src/transformers/models/prophetnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_prophetnet import *
+    from .modeling_prophetnet import *
+    from .tokenization_prophetnet import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/prophetnet/configuration_prophetnet.py b/docs/transformers/src/transformers/models/prophetnet/configuration_prophetnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1219e1faacd42b57402ffab352924c03c3b19623
--- /dev/null
+++ b/docs/transformers/src/transformers/models/prophetnet/configuration_prophetnet.py
@@ -0,0 +1,180 @@
+# coding=utf-8
+# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ProphetNet model configuration"""
+
+from typing import Callable, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ProphetNetConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ProphetNetModel`]. It is used to instantiate a
+    ProphetNet model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the ProphetNet
+    [microsoft/prophetnet-large-uncased](https://huggingface.co/microsoft/prophetnet-large-uncased) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for activations inside the fully connected layer.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the ProphetNet model. Defines the number of different tokens that can be represented by
+            the `input_ids` passed when calling [`ProphetNetModel`].
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        num_encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        num_encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        num_decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        num_decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        add_cross_attention (`bool`, *optional*, defaults to `True`):
+            Whether cross-attention layers should be added to the model.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether this is an encoder/decoder model.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        ngram (`int`, *optional*, defaults to 2):
+            Number of future tokens to predict. Set to 1 to be the same as a traditional language model, predicting
+            only the next first token.
+        num_buckets (`int`, *optional*, defaults to 32):
+            The number of buckets to use for each attention layer. This is for relative position calculation. See the
+            [T5 paper](https://arxiv.org/abs/1910.10683) for more details.
+        relative_max_distance (`int`, *optional*, defaults to 128):
+            Relative distances greater than this number will be put into the last same bucket. This is for relative
+            position calculation. See the [T5 paper](https://arxiv.org/abs/1910.10683) for more details.
+        disable_ngram_loss (`bool`, *optional*, defaults to `False`):
+            Whether to train the model to predict only the next first token.
+        eps (`float`, *optional*, defaults to 0.0):
+            Controls the `epsilon` parameter value for label smoothing in the loss calculation. If set to 0, no label
+            smoothing is performed.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
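+
+    Example (an illustrative usage sketch):
+
+    ```python
+    >>> from transformers import ProphetNetConfig, ProphetNetModel
+
+    >>> # Initializing a ProphetNet style configuration with default values
+    >>> configuration = ProphetNetConfig()
+
+    >>> # Initializing a model (with random weights) from that configuration
+    >>> model = ProphetNetModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```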
+ """ + + model_type = "prophetnet" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "num_encoder_attention_heads", + } + + def __init__( + self, + activation_dropout: Optional[float] = 0.1, + activation_function: Optional[Union[str, Callable]] = "gelu", + vocab_size: Optional[int] = 30522, + hidden_size: Optional[int] = 1024, + encoder_ffn_dim: Optional[int] = 4096, + num_encoder_layers: Optional[int] = 12, + num_encoder_attention_heads: Optional[int] = 16, + decoder_ffn_dim: Optional[int] = 4096, + num_decoder_layers: Optional[int] = 12, + num_decoder_attention_heads: Optional[int] = 16, + attention_dropout: Optional[float] = 0.1, + dropout: Optional[float] = 0.1, + max_position_embeddings: Optional[int] = 512, + init_std: Optional[float] = 0.02, + is_encoder_decoder: Optional[bool] = True, + add_cross_attention: Optional[bool] = True, + decoder_start_token_id: Optional[int] = 0, + ngram: Optional[int] = 2, + num_buckets: Optional[int] = 32, + relative_max_distance: Optional[int] = 128, + disable_ngram_loss: Optional[bool] = False, + eps: Optional[float] = 0.0, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = 0, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.encoder_ffn_dim = encoder_ffn_dim + self.num_encoder_layers = num_encoder_layers + self.num_encoder_attention_heads = num_encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.num_decoder_layers = num_decoder_layers + self.num_decoder_attention_heads = num_decoder_attention_heads + self.max_position_embeddings = max_position_embeddings + self.init_std = init_std # Normal(0, this parameter) + self.activation_function = activation_function + + # parameters for prophetnet + self.ngram = ngram + self.num_buckets = num_buckets + self.relative_max_distance = relative_max_distance + self.disable_ngram_loss = disable_ngram_loss + self.eps = eps + + # 3 Types of Dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.dropout = dropout + + self.use_cache = use_cache + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + add_cross_attention=add_cross_attention, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + @property + def num_hidden_layers(self) -> int: + return self.num_encoder_layers + self.num_decoder_layers + + @num_hidden_layers.setter + def num_hidden_layers(self, value): + raise NotImplementedError( + "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and" + " `num_decoder_layers`." + ) + + +__all__ = ["ProphetNetConfig"] diff --git a/docs/transformers/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..805338511d8a21f5e46998c0f70941659d446025 --- /dev/null +++ b/docs/transformers/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ProphetNet checkpoint."""
+
+import argparse
+
+from torch import nn
+
+# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
+# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
+from transformers_old.modeling_prophetnet import (
+    ProphetNetForConditionalGeneration as ProphetNetForConditionalGenerationOld,
+)
+from transformers_old.modeling_xlm_prophetnet import (
+    XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
+)
+
+from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
+
+
+logger = logging.get_logger(__name__)
+logging.set_verbosity_info()
+
+
+def convert_prophetnet_checkpoint_to_pytorch(prophetnet_checkpoint_path: str, pytorch_dump_folder_path: str):
+    """
+    Copy/paste/tweak prophetnet's weights to our prophetnet structure.
+    """
+    if "xprophetnet" in prophetnet_checkpoint_path:
+        prophet_old = XLMProphetNetForConditionalGenerationOld.from_pretrained(prophetnet_checkpoint_path)
+        prophet, loading_info = XLMProphetNetForConditionalGeneration.from_pretrained(
+            prophetnet_checkpoint_path, output_loading_info=True
+        )
+    else:
+        prophet_old = ProphetNetForConditionalGenerationOld.from_pretrained(prophetnet_checkpoint_path)
+        prophet, loading_info = ProphetNetForConditionalGeneration.from_pretrained(
+            prophetnet_checkpoint_path, output_loading_info=True
+        )
+
+    special_keys = ["key_proj", "value_proj", "query_proj"]
+
+    mapping = {
+        "self_attn": "ngram_self_attn",
+        "cross_attn": "encoder_attn",
+        "cross_attn_layer_norm": "encoder_attn_layer_norm",
+        "feed_forward_layer_norm": "final_layer_norm",
+        "feed_forward": "",
+        "intermediate": "fc1",
+        "output": "fc2",
+        "key_proj": "k_proj",
+        "query_proj": "q_proj",
+        "value_proj": "v_proj",
+        "word_embeddings": "embed_tokens",
+        "embeddings_layer_norm": "emb_layer_norm",
+        "relative_pos_embeddings": "relative_linear",
+        "ngram_embeddings": "ngram_input_embed",
+        "position_embeddings": "embed_positions",
+    }
+
+    for key in loading_info["missing_keys"]:
+        attributes = key.split(".")
+
+        if attributes[0] == "lm_head":
+            model = prophet
+            old_model = prophet_old
+        else:
+            model = prophet.prophetnet
+            old_model = prophet_old.model
+
+        is_key_init = False
+        for attribute in attributes:
+            if attribute in mapping:
+                old_attribute = mapping[attribute]
+                if not hasattr(old_model, old_attribute) and len(old_attribute) > 0:
+                    old_attribute = attribute
+            elif hasattr(old_model, attribute):
+                old_attribute = attribute
+
+            if attribute == "weight":
+                assert old_model.weight.shape == model.weight.shape, "Shapes have to match!"
+                model.weight = old_model.weight
+                logger.info(f"{attribute} is initialized.")
+                is_key_init = True
+                break
+            elif attribute == "bias":
+                assert old_model.bias.shape == model.bias.shape, "Shapes have to match!"
+                model.bias = old_model.bias
+                logger.info(f"{attribute} is initialized.")
+                is_key_init = True
+                break
+            elif attribute in special_keys and hasattr(old_model, "in_proj_weight"):
+                embed_dim = old_model.in_proj_weight.shape[0] // 3
+                param = getattr(model, attribute)
+                assert param.weight.shape == old_model.in_proj_weight[:embed_dim, :].shape, "Shapes have to match"
+                assert param.bias.shape == old_model.in_proj_bias[:embed_dim].shape, "Shapes have to match"
+                if attribute == "query_proj":
+                    model.query_proj.weight = nn.Parameter(old_model.in_proj_weight[:embed_dim, :])
+                    model.query_proj.bias = nn.Parameter(old_model.in_proj_bias[:embed_dim])
+
+                elif attribute == "key_proj":
+                    model.key_proj.weight = nn.Parameter(old_model.in_proj_weight[embed_dim : 2 * embed_dim, :])
+                    model.key_proj.bias = nn.Parameter(old_model.in_proj_bias[embed_dim : 2 * embed_dim])
+                elif attribute == "value_proj":
+                    model.value_proj.weight = nn.Parameter(old_model.in_proj_weight[2 * embed_dim :, :])
+                    model.value_proj.bias = nn.Parameter(old_model.in_proj_bias[2 * embed_dim :])
+                is_key_init = True
+                break
+            elif attribute == "position_embeddings":
+                assert model.position_embeddings.weight.shape[-1] == old_model.embed_positions.weight.shape[-1], (
+                    "Hidden size has to match"
+                )
+                assert model.position_embeddings.weight.shape[0] == 512, "We want 512 position_embeddings."
+                model.position_embeddings.weight = nn.Parameter(old_model.embed_positions.weight[:512, :])
+                is_key_init = True
+                break
+
+            if attribute.isdigit():
+                model = model[int(attribute)]
+                old_model = old_model[int(old_attribute)]
+            else:
+                model = getattr(model, attribute)
+
+                if old_attribute == "":
+                    old_model = old_model
+                else:
+                    if not hasattr(old_model, old_attribute):
+                        raise ValueError(f"{old_model} does not have {old_attribute}")
+                    old_model = getattr(old_model, old_attribute)
+
+        if not is_key_init:
+            raise ValueError(f"{key} was not correctly initialized!")
+
+    print(f"Saving model to {pytorch_dump_folder_path}")
+    prophet.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--prophetnet_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
+    args = parser.parse_args()
+    convert_prophetnet_checkpoint_to_pytorch(args.prophetnet_checkpoint_path, args.pytorch_dump_folder_path)
diff --git a/docs/transformers/src/transformers/models/prophetnet/modeling_prophetnet.py b/docs/transformers/src/transformers/models/prophetnet/modeling_prophetnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb72211b0e64695820914c42a63a548ceb9a74a
--- /dev/null
+++ b/docs/transformers/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -0,0 +1,2321 @@
+# coding=utf-8
+# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ProphetNet model, ported from the ProphetNet repo (fairseq version)."""
+
+import copy
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import Tensor, nn
+from torch.nn import LayerNorm
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_prophetnet import ProphetNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "ProphetNetConfig"
+_CHECKPOINT_FOR_DOC = "microsoft/prophetnet-large-uncased"
+
+
+PROPHETNET_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    Original ProphetNet code can be found [here](https://github.com/microsoft/ProphetNet). Checkpoints were converted
+    from original Fairseq checkpoints. For more information on the checkpoint conversion, please take a look at the
+    file `convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py`.
+
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`ProphetNetConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PROPHETNET_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+ + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def softmax(hidden_state, dim, onnx_trace=False): + if onnx_trace: + return nn.functional.softmax(hidden_state.float(), dim=dim) + else: + return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32) + + +def ngram_attention_bias(sequence_length, ngram, device, dtype): + """ + This function computes the bias for the predict stream + """ + left_block = ( + torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min + ) + right_block = left_block.detach().clone() + # create bias + for stream_idx in range(ngram): + right_block[stream_idx].fill_diagonal_(0, wrap=False) + left_block[stream_idx].triu_(-stream_idx + 1) + + left_block[:, :, 0] = 0 + return torch.cat([left_block, right_block], dim=2) + + +def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False): + """ + This function computes individual parts of the relative position buckets. For more detail, see paper. + """ + inv_relative_positions = -relative_positions + rel_positions_bucket = 0 + + if is_bidirectional: + num_buckets = num_buckets // 2 + rel_positions_bucket = ( + rel_positions_bucket + + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets + ) + inv_relative_positions = torch.abs(inv_relative_positions) + else: + inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions)) + + max_exact = num_buckets // 2 + is_small = torch.lt(inv_relative_positions, max_exact) + val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log( + max_distance / max_exact + ) * (num_buckets - max_exact) + val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int() + rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large) + return rel_positions_bucket + + +def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids): + """ + This function computes both main and predict relative position buckets. For more detail, see paper. 
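+
+    As an illustration: given `position_ids` of shape `[batch_size, sequence_length]`, the main stream buckets the
+    pairwise position offsets into a tensor of shape `[batch_size, sequence_length, sequence_length]`, while the
+    predict stream additionally attends to the shifted positions `position_ids - 1` and yields buckets of shape
+    `[batch_size, sequence_length, 2 * sequence_length]`.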
+ """ + # main stream + main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1) + main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1) + + # predicting stream + predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1) + predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1) + predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1) + + # get both position buckets + main_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False + ) + predict_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False + ) + return main_relative_position_buckets, predict_relative_position_buckets + + +@dataclass +class ProphetNetSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`): + Prediction scores of the main stream language modeling head (scores for each vocabulary token before + SoftMax). + logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`): + Prediction scores of the predict stream language modeling head (scores for each vocabulary token before + SoftMax). + past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, + num_attn_heads, decoder_sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, decoder_sequence_length, hidden_size)`. + + Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs. + decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. 
+        decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+            compute the weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, encoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, encoder_sequence_length)`. Attentions weights of the encoder, after the attention
+            softmax, used to compute the weighted average in the self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    logits_ngram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+    @property
+    def decoder_cross_attentions(self):
+        warnings.warn(
+            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+            " instead.",
+            FutureWarning,
+        )
+        return self.cross_attentions
+
+
+@dataclass
+class ProphetNetSeq2SeqModelOutput(ModelOutput):
+    """
+    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential
+    decoding.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
+            Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`, *optional*):
+            Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
+        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+            outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+            compute the weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, encoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, encoder_sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    last_hidden_state: torch.FloatTensor
+    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+    @property
+    def decoder_cross_attentions(self):
+        warnings.warn(
+            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+            " instead.",
+            FutureWarning,
+        )
+        return self.cross_attentions
+
+
+@dataclass
+class ProphetNetDecoderModelOutput(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
+            Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`):
+            Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
+        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+        ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+            outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+            compute the weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: torch.FloatTensor
+    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class ProphetNetDecoderLMOutput(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
+            Prediction scores of the main stream language modeling head (scores for each vocabulary token before
+            SoftMax).
+        logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
+            Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
+            SoftMax).
+        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+        ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+            outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+            compute the weighted average in the cross-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    logits_ngram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class ProphetNetPreTrainedModel(PreTrainedModel):
+    config_class = ProphetNetConfig
+    base_model_prefix = "prophetnet"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _shift_right(self, input_ids):
+        decoder_start_token_id = self.config.decoder_start_token_id
+        pad_token_id = self.config.pad_token_id
+
+        assert decoder_start_token_id is not None, (
+            "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
+            " pad_token_id. See ProphetNet docs for more information"
+        )
+
+        # shift inputs to the right
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
+
+        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        # replace possible -100 values in labels by `pad_token_id`
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
+
+        return shifted_input_ids
+
+
+class ProphetNetPositionalEmbeddings(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
+    based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
+    the forward function.
+    """
+
+    def __init__(self, config: ProphetNetConfig) -> None:
+        self.max_length = config.max_position_embeddings
+        super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
+
+    def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
+        assert (position_ids is None) or (self.padding_idx is None), (
+            "If position_ids is pre-computed then padding_idx should not be set."
+        )
+
+        if position_ids is None:
+            if past_key_values is not None:
+                # position_ids is the same for every token when decoding a single step
+                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
+                prev_num_input_ids = past_key_values[0][0].shape[2]
+                num_input_ids = inputs_shape[1] + prev_num_input_ids
+                position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
+                    int(self.padding_idx + num_input_ids)
+                )
+            else:
+                if attention_mask is None:
+                    attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)
+
+                # retrieve position_ids from input_ids / attention_mask
+                position_ids = (
+                    torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
+                ).long() + self.padding_idx
+
+                # make sure position_ids are not bigger than max_length
+                position_ids = position_ids.clamp(0, self.max_length - 1)
+
+        return super().forward(position_ids), position_ids
+
+    def _forward(self, position_ids):
+        return super().forward(position_ids)
+
+
+class ProphetNetAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config: ProphetNetConfig,
+        num_attn_heads: int,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.attention_dropout = config.attention_dropout
+        self.dropout = config.dropout
+        self.num_attn_heads = num_attn_heads
+        self.head_dim = hidden_size // num_attn_heads
+
+        assert self.head_dim * num_attn_heads == hidden_size, (
+            "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
+            " `config.num_decoder_attention_heads`"
+        )
+
+        self.key_proj = nn.Linear(hidden_size, hidden_size)
+        self.value_proj = nn.Linear(hidden_size, hidden_size)
+        self.query_proj = nn.Linear(hidden_size, hidden_size)
+
+        self.out_proj = nn.Linear(hidden_size, hidden_size)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states,
+        key_value_states: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        layer_head_mask:
Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor]] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + batch_size, tgt_len, hidden_size = hidden_states.size() + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + assert list(hidden_states.size()) == [ + batch_size, + tgt_len, + hidden_size, + ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}" + + # previous time steps are cached - no need to recompute key and value if they are static + query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.key_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.value_proj(key_value_states), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.key_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.value_proj(hidden_states), -1, batch_size) + + if is_cross_attention: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + # project states into the correct shape + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + src_len = key_states.size(2) + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) + if attn_weights.size() != expected_shape: + raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}") + + # This is part of a workaround to get around fork/join parallelism not supporting Optional types. 
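+        # (A 0-dim dummy tensor is presumably what callers pass in place of `None` when Optional
+        # types are unavailable, so a dimensionless mask is treated here as "no mask".)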
+        if attention_mask is not None and attention_mask.dim() == 0:
+            attention_mask = None
+
+        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
+        if attention_mask is not None and attention_mask.size() != expected_shape:
+            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
+        if attention_mask is not None:  # don't attend to padding symbols
+            attn_weights = attn_weights + attention_mask
+        if output_attentions:
+            attn_weights_reshaped = attn_weights
+        else:
+            attn_weights_reshaped = None
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_attn_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
+                batch_size, self.num_attn_heads, tgt_len, src_len
+            )
+
+            # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
+            if attn_weights_reshaped is not None:
+                attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
+
+        attn_probs = nn.functional.dropout(
+            attn_weights,
+            p=self.attention_dropout,
+            training=self.training,
+        )
+        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+        if attn_output.size() != expected_shape:
+            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
+
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
+        attn_output = self.out_proj(attn_output)
+
+        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class ProphetNetFeedForward(nn.Module):
+    """
+    This is the residual two-layer feed-forward block based on the original Transformer implementation.
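+
+    In effect it computes `dropout(output(dropout(activation(intermediate(x)))))`; the residual connection itself is
+    added by the calling encoder/decoder layer rather than inside this module.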
+ """ + + def __init__(self, config: ProphetNetConfig, ffn_dim: int): + super().__init__() + self.activation_fn = ACT2FN[config.activation_function] + self.intermediate = nn.Linear(config.hidden_size, ffn_dim) + self.output = nn.Linear(ffn_dim, config.hidden_size) + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout + + def forward(self, hidden_states): + hidden_states = self.intermediate(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.output(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class ProphetNetNgramSelfAttention(nn.Module): + def __init__(self, config: ProphetNetConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.num_buckets = config.num_buckets + self.relative_max_distance = config.relative_max_distance + self.num_attn_heads = config.num_decoder_attention_heads + self.dropout = config.dropout + self.attention_dropout = config.attention_dropout + self.head_dim = config.hidden_size // self.num_attn_heads + self.ngram = config.ngram + + assert self.head_dim * self.num_attn_heads == config.hidden_size, ( + "config.hidden_size must be divisible by num_attn_heads" + ) + # key, value, query projection + self.key_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.value_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.query_proj = nn.Linear(config.hidden_size, config.hidden_size) + + # out projection + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + # rel position embeddings + self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads) + + # for onnx runtime + self.onnx_trace = False + + def _shape(self, tensor, seq_len, batch_size): + return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[Tensor]] = None, + attention_mask=None, + layer_head_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + ): + batch_size, ngram_sequence_length, hidden_size = hidden_states.size() + assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], ( + f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape" + f" {hidden_states.shape}" + ) + + # project + query_states = self.query_proj(hidden_states) + key_states = self.key_proj(hidden_states) + value_states = self.value_proj(hidden_states) + + # normalize + query_states = query_states / (self.head_dim**0.5) + + # reshape + query_states = self._shape(query_states, ngram_sequence_length, batch_size) + key_states = self._shape(key_states, -1, batch_size) + value_states = self._shape(value_states, -1, batch_size) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) + + query_states = query_states.view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + # chunk into main stream and predict stream + hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) + query_states_list = query_states.chunk(1 + self.ngram, dim=2) + key_states_list = 
key_states.chunk(1 + self.ngram, dim=2)
+        value_states_list = value_states.chunk(1 + self.ngram, dim=2)
+
+        main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
+        main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
+        main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
+        main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]
+
+        # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
+        if past_key_value is not None:
+            prev_main_key_states = past_key_value[0]
+            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
+            prev_main_value_states = past_key_value[1]
+            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
+
+        # Update cache
+        past_key_value = (main_key_states, main_value_states)
+
+        # get seq_length of main stream only
+        sequence_length = ngram_sequence_length // (1 + self.ngram)
+
+        # MAIN-STREAM
+        # main attn weights
+        # [batch_size, number_heads, sequence_length, head_dimension]
+        # x [batch_size, number_heads, head_dimension, sequence_length]
+        # -> [batch_size, number_heads, sequence_length, sequence_length]
+        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
+
+        # retrieve relative position embeddings for each layer -> see paper for more details
+        main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
+            main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
+        )
+
+        main_attn_weights = main_attn_weights + main_relative_pos_embeddings
+
+        if attention_mask is not None:
+            main_attn_weights = main_attn_weights + attention_mask
+
+        main_attn_probs = softmax(
+            main_attn_weights,
+            dim=-1,
+            onnx_trace=self.onnx_trace,
+        ).type_as(main_attn_weights)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_attn_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+            main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
+                batch_size, self.num_attn_heads, -1, sequence_length
+            )
+
+        main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
+        # project to attn_output
+        # [batch_size, number_heads, sequence_length, sequence_length]
+        # x [batch_size, number_heads, sequence_length, head_dimension]
+        # -> [batch_size, number_heads, sequence_length, head_dimension]
+        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
+        # reshape so that num_heads dim is merged into last `head_dim` axis
+        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
+        main_attn_output = self.out_proj(main_attn_output)
+
+        # PREDICT-STREAM
+        # [batch_size, ngram, number_heads, sequence_length, head_dimension]
+        predict_query_states = torch.stack(predict_query_states_list, 1).view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
+        )
+
+        # [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
+        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
+
+        # [batch_size, sequence_length, ngram, hidden_size]
+        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
+
+        # [batch_size, number_heads, ngram, 2*sequence_length, head_dimension]
+        predict_value_states = torch.cat(
+            [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
+        )
+
+        # [batch_size, ngram, number_heads, sequence_length, head_dimension]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
+        # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
+
+        # retrieve relative position embeddings for each layer -> see paper for more details
+        # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
+        predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
+            predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
+        )
+
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
+
+        if extended_predict_attention_mask is not None:
+            # permute the predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
+            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
+
+        predict_attn_probs = softmax(
+            predict_attn_weights,
+            dim=-1,
+            onnx_trace=self.onnx_trace,
+        ).type_as(predict_attn_weights)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_attn_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
+
+        predict_attn_probs = nn.functional.dropout(
+            predict_attn_probs, p=self.attention_dropout, training=self.training
+        )
+        # project to attention output
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
+        # -> [batch_size, ngram, number_heads, sequence_length, head_dimension]
+        predict_attn_output = torch.einsum(
+            "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
+        )
+
+        # reshape so that num_heads dim is merged into last `head_dim` axis
+        # [batch_size, ngram, number_heads, sequence_length, head_dimension] -> [batch_size, ngram, sequence_length, hidden_size]
+        predict_attn_output = predict_attn_output.transpose(2, 3)
+        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
+        predict_attn_output = self.out_proj(predict_attn_output)
+
+        # concat to single attn output
+        # [batch_size, (1+ngram)*sequence_length, hidden_size]
+        attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
+        # reshape into better form for `config.output_attentions`
+        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
+
+        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
+
+        return attn_output, main_attn_probs, predict_attn_probs, past_key_value
+
+    def get_main_relative_pos_embeddings(
+        self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
+    ):
+        # input hidden_states [batch_size, sequence_length, hidden_size]
+        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
+        # 
input position_ids [batch_size, sequence_length] or [1,1] + batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape + attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len) + if main_relative_position_buckets is None: + batch_size, sequence_length = hidden_states.shape[:2] + relative_positions = ( + torch.arange(1, attn_weights.shape[-1] + 1) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, sequence_length, 1) + .to(position_ids.device) + ) + # [batch_size, sequence_length, sequence_length+1] + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) + main_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, relative_positions, False + ) + + # [batch_size, sequence_length, num_buckets * num_heads] + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + rel_pos_embeddings = rel_pos_embeddings.view( + rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads) + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2) + # [batch_size, num_heads, sequence_length, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,)) + + main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) + # [batch_size * num_heads * sequence_length, sequence_length] + main_relative_position_buckets = main_relative_position_buckets.view( + -1, main_relative_position_buckets.shape[-1] + ) + main_relative_position_buckets = main_relative_position_buckets.long() + # [batch_size * num_heads * sequence_length, sequence_length] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + + main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets) + main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1) + return main_relative_pos_embeddings + + def get_predict_relative_pos_embeddings( + self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets + ): + # input hidden_states [batch_size, sequence_length, ngram, hidden_size] + # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None + batch_size, sequence_length = hidden_states.shape[0:2] + + if predict_relative_position_buckets is None: + key_sequence_length = attn_weights.shape[-1] + assert position_ids[0][0] == key_sequence_length - 1, ( + "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... 
(key_sequence_length - 1)"
+            )
+            relative_positions = (
+                torch.arange(0, key_sequence_length)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .repeat(batch_size, sequence_length, 1)
+                .to(position_ids.device)
+            )
+
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
+            predict_relative_position_buckets = compute_relative_buckets(
+                self.num_buckets, self.relative_max_distance, relative_positions, False
+            )
+
+        # [batch_size, ngram, sequence_length, hidden_size]
+        hidden_states = hidden_states.transpose(1, 2)
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+
+        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
+        rel_pos_embeddings = rel_pos_embeddings.view(
+            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
+        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
+        # [ngram, batch_size, num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
+        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
+            self.ngram, 1, self.num_attn_heads, 1
+        )
+        # [ngram * batch_size * num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.view(
+            -1, predict_relative_position_buckets.size(-1)
+        ).long()
+
+        predict_relative_pos_embeddings = torch.gather(
+            rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
+        )
+
+        # [batch_size, ngram, num_heads, sequence_length, -1]
+        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
+        )
+
+        return predict_relative_pos_embeddings
+
+
+class ProphetNetEncoderLayer(nn.Module):
+    """
+    Encoder block for ProphetNet
+    """
+
+    def __init__(self, config: ProphetNetConfig):
+        super().__init__()
+        # 1st residual block
+        self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)
+
+        # 2nd residual block
+        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        layer_head_mask,
+        output_attentions: bool = False,
+    ):
+        # 1st residual block
+        attention_output, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
+
+        # 2nd residual block
+        feed_forward_output = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class ProphetNetDecoderLayer(nn.Module):
+    """
+    Decoder block for ProphetNet
+    """
+
+    def __init__(self, config: ProphetNetConfig):
+        super().__init__()
+        # 1st residual block
+        self.self_attn = ProphetNetNgramSelfAttention(config)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)
+
+        # 2nd residual block
+        if config.add_cross_attention:
+            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
+            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
+
+        # 3rd residual block
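+        # (orientation note: each decoder layer stacks three residual sub-blocks: the n-gram
+        # self-attention above, the optional cross-attention over the encoder states, and the
+        # feed-forward block below; each sub-block is wrapped in a residual add + LayerNorm in `forward`)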
+ self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attn_mask=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + past_key_value=None, + use_cache: bool = True, + output_attentions: bool = False, + ): + # 1st residual block + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + ) + hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attn_weights = None + if encoder_hidden_states is not None: + # 2nd residual block + attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attn_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # 3rd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The standalone encoder part of the ProphetNetModel.", + PROPHETNET_START_DOCSTRING, +) +class ProphetNetEncoder(ProphetNetPreTrainedModel): + r""" + word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*): + The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word + embeddings instead of randomly initialized word embeddings. 
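+
+    Example of passing shared, pre-defined embeddings (an illustrative sketch; the encoder weights
+    here are randomly initialized rather than pretrained):
+
+    ```python
+    >>> from transformers import ProphetNetConfig, ProphetNetEncoder
+    >>> import torch
+
+    >>> config = ProphetNetConfig()
+    >>> shared_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+    >>> encoder = ProphetNetEncoder(config, word_embeddings=shared_embeddings)
+    ```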
+ """ + + def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None): + super().__init__(config) + + self.word_embeddings = ( + word_embeddings + if word_embeddings is not None + else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + ) + self.position_embeddings = ProphetNetPositionalEmbeddings(config) + self.embeddings_layer_norm = LayerNorm(config.hidden_size) + + self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + + @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetEncoder + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None and inputs_embeds is None: + raise ValueError("Either input_ids or inputs_embeds has to be passed.") + elif input_ids is not None and inputs_embeds is not None: + raise ValueError("Make sure to only pass input_ids or inputs_embeds.") + elif input_ids is not None and inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # prepare attention mask + if attention_mask is not None: + extended_attention_mask = ( + 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1) + ) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype) + else: + extended_attention_mask = None + + position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device) + + hidden_states = inputs_embeds + position_embeddings + hidden_states = self.embeddings_layer_norm(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training) + + encoder_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for 
{head_mask.size()[0]}."
+            )
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_hidden_states = encoder_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    extended_attention_mask,
+                    (head_mask[idx] if head_mask is not None else None),
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_hidden_states = encoder_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
+        )
+
+
+@add_start_docstrings(
+    "The standalone decoder part of the ProphetNetModel.",
+    PROPHETNET_START_DOCSTRING,
+)
+class ProphetNetDecoder(ProphetNetPreTrainedModel):
+    r"""
+    word_embeddings (`torch.nn.Embedding` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
+        The word embedding parameters. This can be used to initialize [`ProphetNetDecoder`] with pre-defined word
+        embeddings instead of randomly initialized word embeddings.
+    """
+
+    def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
+        super().__init__(config)
+
+        self.ngram = config.ngram
+        self.num_buckets = config.num_buckets
+        self.relative_max_distance = config.relative_max_distance
+        self.dropout = config.dropout
+        self.max_target_positions = config.max_position_embeddings
+
+        self.word_embeddings = (
+            word_embeddings
+            if word_embeddings is not None
+            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        )
+        self.position_embeddings = ProphetNetPositionalEmbeddings(config)
+
+        self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
+        self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.word_embeddings = value
+
+    @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=ProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProphetNetDecoderModelOutput]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of 
shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, ProphetNetDecoder
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
+        >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> last_hidden_states = outputs.last_hidden_state
+        ```"""
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.")
+        elif input_ids is not None and inputs_embeds is not None:
+            raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.")
+        elif input_ids is not None and inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        main_stream_pos_embed, position_ids = self.position_embeddings(
+            (batch_size, sequence_length),
+            device=inputs_embeds.device,
+            past_key_values=past_key_values,
+        )
+
+        if past_key_values is not None:
+            main_relative_position_buckets, predict_relative_position_buckets = None, None
+        else:
+            (
+                main_relative_position_buckets,
+                predict_relative_position_buckets,
+            ) = self.compute_buffered_relative_buckets(position_ids)
+        predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1)
+
+        # add position embeddings
+        
hidden_states = inputs_embeds + main_stream_pos_embed + + ngram_embeddings = self.ngram_embeddings.weight + + # prepare attention mask + if past_key_values is not None: + assert hidden_states.size(1) == 1, ( + "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" + ) + + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1) + for ngram in range(self.ngram) + ] + extended_attention_mask = None + extended_predict_attention_mask = None + else: + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram) + ] + extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask) + extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask) + + # prepare encoder attention mask + if encoder_attention_mask is not None: + extended_encoder_attention_mask = ( + 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1) + ) * torch.finfo(self.dtype).min + extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype) + else: + extended_encoder_attention_mask = None + + hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1) + + if self.embeddings_layer_norm: + hidden_states = self.embeddings_layer_norm(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # init attentions, hidden_states and cache with empty tuples + all_main_stream_hidden_states = () if output_hidden_states else None + all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None + + all_main_stream_attns = () if output_attentions else None + all_ngram_stream_attns = () if output_attentions else None + all_cross_attns = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + present_key_values = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + # grad cannot be kept because tensor is sliced + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) + if self.config.ngram > 0: + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + extended_attention_mask, + encoder_hidden_states, + extended_encoder_attention_mask, + (head_mask[idx] if head_mask is not None else None), + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + extended_predict_attention_mask, + main_relative_position_buckets, + predict_relative_position_buckets, + position_ids, + None, + use_cache, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + present_key_values += (layer_outputs[4 if output_attentions else 1],) + + if output_attentions: + all_main_stream_attns += (layer_outputs[1],) + all_ngram_stream_attns += (layer_outputs[2],) + + if self.config.add_cross_attention: + all_cross_attns += (layer_outputs[3],) + + if output_hidden_states: + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) + if self.config.ngram > 0: + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + # split last_hidden_state for return + last_hidden_state = hidden_states[:, :sequence_length] + last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None + + if not return_dict: + return tuple( + v + for v in [ + last_hidden_state, + last_hidden_state_ngram, + present_key_values, + all_main_stream_hidden_states, + all_ngram_stream_hidden_states, + all_main_stream_attns, + all_ngram_stream_attns, + all_cross_attns, + ] + if v is not None + ) + return ProphetNetDecoderModelOutput( + last_hidden_state=last_hidden_state, + last_hidden_state_ngram=last_hidden_state_ngram, + past_key_values=present_key_values, + hidden_states=all_main_stream_hidden_states, + hidden_states_ngram=all_ngram_stream_hidden_states, + attentions=all_main_stream_attns, + ngram_attentions=all_ngram_stream_attns, + cross_attentions=all_cross_attns, + ) + + def compute_buffered_relative_buckets(self, position_ids): + batch_size, sequence_length = position_ids.shape + + position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1) + main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets( + self.num_buckets, self.relative_max_distance, position_ids + ) + + # buffer relative buckets + main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1) 
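+        # (note: the predict stream attends both to the main-stream positions and to its own
+        # shifted positions, so its buckets below concatenate the main block with the block
+        # offset by `max_target_positions`, matching the layout returned by
+        # `compute_all_stream_relative_buckets`)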
+ predict_relative_buckets = torch.cat( + [ + predict_relative_buckets[:, :sequence_length, :sequence_length], + predict_relative_buckets[ + :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length + ], + ], + 2, + ).repeat(batch_size, 1, 1) + + return main_relative_buckets, predict_relative_buckets + + def prepare_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + causal_mask = torch.full( + (seq_length, seq_length), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + causal_mask = torch.triu(causal_mask, 1) + + extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape + ) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_causal_mask + extended_attention_mask + else: + extended_attention_mask = extended_causal_mask + return extended_attention_mask.to(hidden_states.dtype) + + def prepare_predict_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + predict_causal_mask = ngram_attention_bias( + self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype + ) + predict_causal_mask = torch.cat( + [ + predict_causal_mask[:, :seq_length, :seq_length], + predict_causal_mask[ + :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length + ], + ], + dim=-1, + ) + extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape + ) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.expand( + (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length) + ) + # predicted stream attention_mask should always be 0 + extended_attention_mask = torch.cat( + [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1 + ) + extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask + else: + extended_predict_attention_mask = extended_predict_causal_mask + return extended_predict_attention_mask.to(hidden_states.dtype) + + +@add_start_docstrings( + "The bare ProphetNet Model outputting raw hidden-states without any specific head on top.", + PROPHETNET_START_DOCSTRING, +) +class ProphetNetModel(ProphetNetPreTrainedModel): + _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"] + + def __init__(self, config: ProphetNetConfig): + super().__init__(config) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + encoder_config = copy.deepcopy(config) + encoder_config.is_encoder_decoder = False + encoder_config.use_cache = False + self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings) + + # Initialize weights and apply final processing + 
self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + self.encoder.word_embeddings = self.word_embeddings + self.decoder.word_embeddings = self.word_embeddings + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings) + self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(PROPHETNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ProphetNetSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ProphetNetSeq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ProphetNetModel + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased") + >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... 
).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states + >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + return ProphetNetSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram, + decoder_attentions=decoder_outputs.attentions, + decoder_ngram_attentions=decoder_outputs.ngram_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The ProphetNet Model with a language modeling head. 
Can be used for sequence generation tasks.",
+    PROPHETNET_START_DOCSTRING,
+)
+class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
+
+    def __init__(self, config: ProphetNetConfig):
+        super().__init__(config)
+        self.prophetnet = ProphetNetModel(config)
+        self.padding_idx = config.pad_token_id
+        self.disable_ngram_loss = config.disable_ngram_loss
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)
+
+    def get_input_embeddings(self):
+        return self.prophetnet.word_embeddings
+
+    @add_start_docstrings_to_model_forward(PROPHETNET_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=ProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProphetNetSeq2SeqLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
+            labels in `[0, ..., config.vocab_size - 1]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
+        >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
+
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+        ... ).input_ids  # Batch size 1
+        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
+        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+
+        >>> logits_next_token = outputs.logits  # logits to predict next token as usual
+        >>> logits_ngram_next_tokens = outputs.logits_ngram  # logits to predict 2nd, 3rd, ... 
next tokens + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + outputs = self.prophetnet( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + batch_size, sequence_length = ( + decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2] + ) + + predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None + + # To use .view in loss computation, make sure that logits is contiguous. + if not logits.is_contiguous(): + logits = logits.contiguous() + + loss = None + if labels is not None: + loss = self._compute_loss(predict_logits, labels) + + if not return_dict: + all_logits = tuple(v for v in [logits, logits_ngram] if v is not None) + return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:] + else: + return ProphetNetSeq2SeqLMOutput( + loss=loss, + logits=logits, + logits_ngram=logits_ngram, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states, + decoder_attentions=outputs.decoder_attentions, + decoder_ngram_attentions=outputs.decoder_ngram_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def _compute_loss(self, logits, labels, ignore_index=-100): + expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index) + + for i in range(self.config.ngram): + if i > 0 and self.disable_ngram_loss: + break + expend_targets[i, :, :] = labels + + logits = logits.transpose(0, 1).contiguous() + lprobs = nn.functional.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ) + + loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") + + if self.config.eps > 0.0: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + non_masked_tokens = expend_targets.ne(ignore_index).view(-1) + smooth_loss = smooth_loss[non_masked_tokens] + smooth_loss = smooth_loss.mean() + + eps_i = self.config.eps / lprobs.size(-1) + loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss + + return loss + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + @staticmethod + # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be 
reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + def get_encoder(self): + return self.prophetnet.encoder + + def get_decoder(self): + return self.prophetnet.decoder + + +@add_start_docstrings( + "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal" + " language modeling.", + PROPHETNET_START_DOCSTRING, +) +class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "prophetnet.word_embeddings.weight", + "prophetnet.decoder.word_embeddings.weight", + "lm_head.weight", + ] + + def __init__(self, config: ProphetNetConfig): + # set config for CLM + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.prophetnet = ProphetNetDecoderWrapper(config) + + self.padding_idx = config.pad_token_id + self.disable_ngram_loss = config.disable_ngram_loss + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prophetnet.decoder.word_embeddings + + def set_input_embeddings(self, value): + self.prophetnet.decoder.word_embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) + + def set_decoder(self, decoder): + self.prophetnet.decoder = decoder + + def get_decoder(self): + return self.prophetnet.decoder + + @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ProphetNetDecoderLMOutput]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. 
Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, ProphetNetForCausalLM
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
+        >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+
+        >>> # Model can also be used with EncoderDecoder framework
+        >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
+        >>> import torch
+
+        >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
+        >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
+        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+        ...     "google-bert/bert-large-uncased", "microsoft/prophetnet-large-uncased"
+        ... )
+
+        >>> ARTICLE = (
+        ...     "the us state department said wednesday it had received no "
+        ...     "formal word from bolivia that it was expelling the us ambassador there "
+        ...     "but said the charges made against him are `` baseless ."
+        ... )
+        >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
+        >>> labels = tokenizer_dec(
+        ...     "us rejects charges against its ambassador in bolivia", return_tensors="pt"
+        ... 
).input_ids + >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:]) + + >>> loss = outputs.loss + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + outputs = self.prophetnet.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] + + predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None + + loss = None + if labels is not None: + loss = self._compute_loss(predict_logits, labels) + + if not return_dict: + all_logits = tuple(v for v in [logits, logits_ngram] if v is not None) + return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:] + else: + return ProphetNetDecoderLMOutput( + loss=loss, + logits=logits, + logits_ngram=logits_ngram, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + hidden_states_ngram=outputs.hidden_states_ngram, + attentions=outputs.attentions, + ngram_attentions=outputs.ngram_attentions, + cross_attentions=outputs.cross_attentions, + ) + + def _compute_loss(self, logits, labels, ignore_index=-100): + expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index) + + for i in range(self.config.ngram): + if i > 0 and self.disable_ngram_loss: + break + expend_targets[i, :, :] = labels + + logits = logits.transpose(0, 1).contiguous() + lprobs = nn.functional.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ) + + loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") + + if self.config.eps > 0.0: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + non_masked_tokens = expend_targets.ne(ignore_index).view(-1) + smooth_loss = smooth_loss[non_masked_tokens] + smooth_loss = smooth_loss.mean() + + eps_i = self.config.eps / lprobs.size(-1) + loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss + + return loss + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + use_cache=None, + **kwargs, + ): + # Overwritten -- our tests complain if we use GenerationMixin.prepare_inputs_for_generation + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past_key_values: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed + "attention_mask": attention_mask, + "head_mask": head_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @staticmethod + # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): + """ + This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly be loaded from pretrained prophetnet + classes. + """ + + def __init__(self, config: ProphetNetConfig): + super().__init__(config) + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings()) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +__all__ = [ + "ProphetNetDecoder", + "ProphetNetEncoder", + "ProphetNetForCausalLM", + "ProphetNetForConditionalGeneration", + "ProphetNetModel", + "ProphetNetPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/prophetnet/tokenization_prophetnet.py b/docs/transformers/src/transformers/models/prophetnet/tokenization_prophetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..276dbb8438f7fa1ee9680f6f0e357039deac5734 --- /dev/null +++ b/docs/transformers/src/transformers/models/prophetnet/tokenization_prophetnet.py @@ -0,0 +1,507 @@ +# coding=utf-8 +# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import os +import unicodedata +from typing import Iterable, List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "prophetnet.tokenizer"} + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. 
Only has an effect when
+            `do_basic_tokenize=True`.
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+        do_split_on_punc (`bool`, *optional*, defaults to `True`):
+            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+            the full context of the words, such as contractions.
+    """
+
+    def __init__(
+        self,
+        do_lower_case=True,
+        never_split=None,
+        tokenize_chinese_chars=True,
+        strip_accents=None,
+        do_split_on_punc=True,
+    ):
+        if never_split is None:
+            never_split = []
+        self.do_lower_case = do_lower_case
+        self.never_split = set(never_split)
+        self.tokenize_chinese_chars = tokenize_chinese_chars
+        self.strip_accents = strip_accents
+        self.do_split_on_punc = do_split_on_punc
+
+    def tokenize(self, text, never_split=None):
+        """
+        Basic tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+        Args:
+            never_split (`List[str]`, *optional*):
+                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+        """
+        # union() returns a new set by concatenating the two sets.
+        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+        text = self._clean_text(text)
+
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia).
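+        # Surround every CJK codepoint with spaces so that each ideograph becomes its
+        # own whitespace-delimited token below (CJK scripts put no spaces between words).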
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. 
This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+
+        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+        Args:
+            text: A single token or whitespace-separated tokens. This should have
+                already been passed through *BasicTokenizer*.
+
+        Returns:
+            A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+def load_vocab(vocab_file):
+    """Loads a vocabulary file into a dictionary."""
+    vocab = collections.OrderedDict()
+    with open(vocab_file, "r", encoding="utf-8") as reader:
+        tokens = reader.readlines()
+    for index, token in enumerate(tokens):
+        token = token.rstrip("\n")
+        vocab[token] = index
+    return vocab
+
+
+class ProphetNetTokenizer(PreTrainedTokenizer):
+    r"""
+    Construct a ProphetNetTokenizer. Based on WordPiece.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+            Whether or not to do basic tokenization before WordPiece.
+        never_split (`Iterable`, *optional*):
+            Collection of tokens which will never be split during tokenization. Only has an effect when
+            `do_basic_tokenize=True`.
+        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        x_sep_token (`str`, *optional*, defaults to `"[X_SEP]"`):
+            Special second separator token, which can be generated by [`ProphetNetForConditionalGeneration`]. It is
+            used to separate bullet-point-like sentences in summarization.
+        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + # first name has to correspond to main model input name + # to make sure `tokenizer.pad(...)` works correctly + # `ProphetNet` doesn't have `token_type_ids` as argument. + model_input_names: List[str] = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file: str, + do_lower_case: Optional[bool] = True, + do_basic_tokenize: Optional[bool] = True, + never_split: Optional[Iterable] = None, + unk_token: Optional[str] = "[UNK]", + sep_token: Optional[str] = "[SEP]", + x_sep_token: Optional[str] = "[X_SEP]", + pad_token: Optional[str] = "[PAD]", + mask_token: Optional[str] = "[MASK]", + tokenize_chinese_chars: Optional[bool] = True, + strip_accents: Optional[bool] = None, + clean_up_tokenization_spaces: bool = True, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + x_sep_token=x_sep_token, + pad_token=pad_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token: str): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index: int): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens: str): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = 
None,
+        already_has_special_tokens: Optional[bool] = False,
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ProphetNet
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        sep = [self.sep_token_id]
+        if token_ids_1 is None:
+            return len(token_ids_0 + sep) * [0]
+        return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        index = 0
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+        else:
+            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                        " Please check that the vocabulary is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(token + "\n")
+                index += 1
+        return (vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A ProphetNet sequence has the following format:
+
+        - single sequence: `X [SEP]`
+        - pair of sequences: `A [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
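+
+        Example (an illustrative sketch; the token ids `10, 11, 20, 21` and the `[SEP]` id `102` are made up):
+
+        ```python
+        # no [CLS] is added; a [SEP] closes each segment
+        tokenizer.build_inputs_with_special_tokens([10, 11])            # -> [10, 11, 102]
+        tokenizer.build_inputs_with_special_tokens([10, 11], [20, 21])  # -> [10, 11, 102, 20, 21, 102]
+        ```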
+        """
+        if token_ids_1 is None:
+            return token_ids_0 + [self.sep_token_id]
+        sep = [self.sep_token_id]
+        return token_ids_0 + sep + token_ids_1 + sep
+
+
+__all__ = ["ProphetNetTokenizer"]
diff --git a/docs/transformers/src/transformers/models/pvt/__init__.py b/docs/transformers/src/transformers/models/pvt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..371478776da01819236b470d7a96717003a26d74
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pvt/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_pvt import *
+    from .image_processing_pvt import *
+    from .image_processing_pvt_fast import *
+    from .modeling_pvt import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/pvt/configuration_pvt.py b/docs/transformers/src/transformers/models/pvt/configuration_pvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..c97c2703ef2358a38e082aebc01a41e25cc8f5bf
--- /dev/null
+++ b/docs/transformers/src/transformers/models/pvt/configuration_pvt.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+# Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan,
+# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pvt model configuration"""
+
+from collections import OrderedDict
+from typing import Callable, List, Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PvtConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PvtModel`]. It is used to instantiate a Pvt
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Pvt
+    [Xrenya/pvt-tiny-224](https://huggingface.co/Xrenya/pvt-tiny-224) architecture.
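+
+    As a rough orientation (a sketch using the default `strides=[4, 2, 2, 2]` documented below), each stage
+    divides the spatial resolution by its stride, so a 224x224 input is processed at 56, 28, 14 and finally 7
+    pixels per side:
+
+    ```python
+    >>> size = 224
+    >>> [size := size // stride for stride in [4, 2, 2, 2]]
+    [56, 28, 14, 7]
+    ```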
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The input image size + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_encoder_blocks (`int`, *optional*, defaults to 4): + The number of encoder blocks (i.e. stages in the Mix Transformer encoder). + depths (`List[int]`, *optional*, defaults to `[2, 2, 2, 2]`): + The number of layers in each encoder block. + sequence_reduction_ratios (`List[int]`, *optional*, defaults to `[8, 4, 2, 1]`): + Sequence reduction ratios in each encoder block. + hidden_sizes (`List[int]`, *optional*, defaults to `[64, 128, 320, 512]`): + Dimension of each of the encoder blocks. + patch_sizes (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`): + Patch size before each encoder block. + strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`): + Stride before each encoder block. + num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + mlp_ratios (`List[int]`, *optional*, defaults to `[8, 8, 4, 4]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for stochastic depth, used in the blocks of the Transformer encoder. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + num_labels ('int', *optional*, defaults to 1000): + The number of classes. 
+ Example: + + ```python + >>> from transformers import PvtModel, PvtConfig + + >>> # Initializing a PVT Xrenya/pvt-tiny-224 style configuration + >>> configuration = PvtConfig() + + >>> # Initializing a model from the Xrenya/pvt-tiny-224 style configuration + >>> model = PvtModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pvt" + + def __init__( + self, + image_size: int = 224, + num_channels: int = 3, + num_encoder_blocks: int = 4, + depths: List[int] = [2, 2, 2, 2], + sequence_reduction_ratios: List[int] = [8, 4, 2, 1], + hidden_sizes: List[int] = [64, 128, 320, 512], + patch_sizes: List[int] = [4, 2, 2, 2], + strides: List[int] = [4, 2, 2, 2], + num_attention_heads: List[int] = [1, 2, 5, 8], + mlp_ratios: List[int] = [8, 8, 4, 4], + hidden_act: Mapping[str, Callable] = "gelu", + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + initializer_range: float = 0.02, + drop_path_rate: float = 0.0, + layer_norm_eps: float = 1e-6, + qkv_bias: bool = True, + num_labels: int = 1000, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.num_channels = num_channels + self.num_encoder_blocks = num_encoder_blocks + self.depths = depths + self.sequence_reduction_ratios = sequence_reduction_ratios + self.hidden_sizes = hidden_sizes + self.patch_sizes = patch_sizes + self.strides = strides + self.mlp_ratios = mlp_ratios + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.drop_path_rate = drop_path_rate + self.layer_norm_eps = layer_norm_eps + self.num_labels = num_labels + self.qkv_bias = qkv_bias + + +class PvtOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + @property + def default_onnx_opset(self) -> int: + return 12 + + +__all__ = ["PvtConfig", "PvtOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/pvt/convert_pvt_to_pytorch.py b/docs/transformers/src/transformers/models/pvt/convert_pvt_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..633d759123f4cfed554e128b87aaa4b4e9eafcf4 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt/convert_pvt_to_pytorch.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, +# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert Pvt checkpoints from the original library.""" + +import argparse +from pathlib import Path + +import requests +import torch +from PIL import Image + +from transformers import PvtConfig, PvtForImageClassification, PvtImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + for i in range(config.num_encoder_blocks): + # Remane embedings' paramters + rename_keys.append((f"pos_embed{i + 1}", f"pvt.encoder.patch_embeddings.{i}.position_embeddings")) + + rename_keys.append((f"patch_embed{i + 1}.proj.weight", f"pvt.encoder.patch_embeddings.{i}.projection.weight")) + rename_keys.append((f"patch_embed{i + 1}.proj.bias", f"pvt.encoder.patch_embeddings.{i}.projection.bias")) + rename_keys.append((f"patch_embed{i + 1}.norm.weight", f"pvt.encoder.patch_embeddings.{i}.layer_norm.weight")) + rename_keys.append((f"patch_embed{i + 1}.norm.bias", f"pvt.encoder.patch_embeddings.{i}.layer_norm.bias")) + + for j in range(config.depths[i]): + # Rename blocks' parameters + rename_keys.append( + (f"block{i + 1}.{j}.attn.q.weight", f"pvt.encoder.block.{i}.{j}.attention.self.query.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.q.bias", f"pvt.encoder.block.{i}.{j}.attention.self.query.bias") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.kv.weight", f"pvt.encoder.block.{i}.{j}.attention.self.kv.weight") + ) + rename_keys.append((f"block{i + 1}.{j}.attn.kv.bias", f"pvt.encoder.block.{i}.{j}.attention.self.kv.bias")) + + if config.sequence_reduction_ratios[i] > 1: + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.norm.weight", + f"pvt.encoder.block.{i}.{j}.attention.self.layer_norm.weight", + ) + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.norm.bias", f"pvt.encoder.block.{i}.{j}.attention.self.layer_norm.bias") + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.sr.weight", + f"pvt.encoder.block.{i}.{j}.attention.self.sequence_reduction.weight", + ) + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.sr.bias", + f"pvt.encoder.block.{i}.{j}.attention.self.sequence_reduction.bias", + ) + ) + + rename_keys.append( + (f"block{i + 1}.{j}.attn.proj.weight", f"pvt.encoder.block.{i}.{j}.attention.output.dense.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.proj.bias", f"pvt.encoder.block.{i}.{j}.attention.output.dense.bias") + ) + + rename_keys.append((f"block{i + 1}.{j}.norm1.weight", f"pvt.encoder.block.{i}.{j}.layer_norm_1.weight")) + rename_keys.append((f"block{i + 1}.{j}.norm1.bias", f"pvt.encoder.block.{i}.{j}.layer_norm_1.bias")) + + rename_keys.append((f"block{i + 1}.{j}.norm2.weight", f"pvt.encoder.block.{i}.{j}.layer_norm_2.weight")) + rename_keys.append((f"block{i + 1}.{j}.norm2.bias", f"pvt.encoder.block.{i}.{j}.layer_norm_2.bias")) + + rename_keys.append((f"block{i + 1}.{j}.mlp.fc1.weight", f"pvt.encoder.block.{i}.{j}.mlp.dense1.weight")) + rename_keys.append((f"block{i + 1}.{j}.mlp.fc1.bias", f"pvt.encoder.block.{i}.{j}.mlp.dense1.bias")) + rename_keys.append((f"block{i + 1}.{j}.mlp.fc2.weight", f"pvt.encoder.block.{i}.{j}.mlp.dense2.weight")) + rename_keys.append((f"block{i + 1}.{j}.mlp.fc2.bias", f"pvt.encoder.block.{i}.{j}.mlp.dense2.bias")) + + # Rename cls token + rename_keys.extend( + [ + ("cls_token", "pvt.encoder.patch_embeddings.3.cls_token"), + ] + ) + # Rename norm layer and classifier layer + rename_keys.extend( + [ + 
("norm.weight", "pvt.encoder.layer_norm.weight"), + ("norm.bias", "pvt.encoder.layer_norm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_k_v(state_dict, config): + # for each of the encoder blocks: + for i in range(config.num_encoder_blocks): + for j in range(config.depths[i]): + # read in weights + bias of keys and values (which is a single matrix in the original implementation) + kv_weight = state_dict.pop(f"pvt.encoder.block.{i}.{j}.attention.self.kv.weight") + kv_bias = state_dict.pop(f"pvt.encoder.block.{i}.{j}.attention.self.kv.bias") + # next, add keys and values (in that order) to the state dict + state_dict[f"pvt.encoder.block.{i}.{j}.attention.self.key.weight"] = kv_weight[: config.hidden_sizes[i], :] + state_dict[f"pvt.encoder.block.{i}.{j}.attention.self.key.bias"] = kv_bias[: config.hidden_sizes[i]] + + state_dict[f"pvt.encoder.block.{i}.{j}.attention.self.value.weight"] = kv_weight[ + config.hidden_sizes[i] :, : + ] + state_dict[f"pvt.encoder.block.{i}.{j}.attention.self.value.bias"] = kv_bias[config.hidden_sizes[i] :] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_pvt_checkpoint(pvt_size, pvt_checkpoint, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our PVT structure. + """ + + # define default Pvt configuration + if pvt_size == "tiny": + config_path = "Zetatech/pvt-tiny-224" + elif pvt_size == "small": + config_path = "Zetatech/pvt-small-224" + elif pvt_size == "medium": + config_path = "Zetatech/pvt-medium-224" + elif pvt_size == "large": + config_path = "Zetatech/pvt-large-224" + else: + raise ValueError(f"Available model's size: 'tiny', 'small', 'medium', 'large', but '{pvt_size}' was given") + config = PvtConfig(name_or_path=config_path) + # load original model from https://github.com/whai362/PVT + state_dict = torch.load(pvt_checkpoint, map_location="cpu", weights_only=True) + + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_k_v(state_dict, config) + + # load HuggingFace model + model = PvtForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by PVTFeatureExtractor + image_processor = PvtImageProcessor(size=config.image_size) + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + logits = outputs.logits.detach().cpu() + + if pvt_size == "tiny": + expected_slice_logits = torch.tensor([-1.4192, -1.9158, -0.9702]) + elif pvt_size == "small": + expected_slice_logits = torch.tensor([0.4353, -0.1960, -0.2373]) + elif pvt_size == "medium": + expected_slice_logits = torch.tensor([-0.2914, -0.2231, 0.0321]) + elif pvt_size == "large": + expected_slice_logits = torch.tensor([0.3740, -0.7739, -0.4214]) + else: + raise ValueError(f"Available model's size: 'tiny', 'small', 'medium', 'large', but '{pvt_size}' was given") + + assert torch.allclose(logits[0, :3], expected_slice_logits, atol=1e-4) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model pytorch_model.bin to 
{pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--pvt_size", + default="tiny", + type=str, + help="Size of the PVT pretrained model you'd like to convert.", + ) + parser.add_argument( + "--pvt_checkpoint", + default="pvt_tiny.pth", + type=str, + help="Checkpoint of the PVT pretrained model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_pvt_checkpoint(args.pvt_size, args.pvt_checkpoint, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/pvt/image_processing_pvt.py b/docs/transformers/src/transformers/models/pvt/image_processing_pvt.py new file mode 100644 index 0000000000000000000000000000000000000000..75fcba6ea14f0e46146be17deed3b401c14d8440 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt/image_processing_pvt.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Pvt.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging + + +logger = logging.get_logger(__name__) + + +class PvtImageProcessor(BaseImageProcessor): + r""" + Constructs a PVT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. 
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + self.do_resize = do_resize + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has + an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. 
If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + resample = resample if resample is not None else self.resample + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size_dict = get_size_dict(size) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["PvtImageProcessor"] diff --git a/docs/transformers/src/transformers/models/pvt/image_processing_pvt_fast.py b/docs/transformers/src/transformers/models/pvt/image_processing_pvt_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..a371c8b32bfad97482a46afac4fcc1a37b280175 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt/image_processing_pvt_fast.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Pvt.""" + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BaseImageProcessorFast, +) +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling +from ...utils import add_start_docstrings + + +@add_start_docstrings( + "Constructs a fast Pvt image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class PvtImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 224, "width": 224} + default_to_square = True + crop_size = None + do_resize = True + do_center_crop = None + do_rescale = True + do_normalize = True + do_convert_rgb = None + model_input_names = ["pixel_values"] + + +__all__ = ["PvtImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/pvt/modeling_pvt.py b/docs/transformers/src/transformers/models/pvt/modeling_pvt.py new file mode 100644 index 0000000000000000000000000000000000000000..f9eb1a3c586cee81a0dd3342f5ae8a1fa2ac8ba6 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt/modeling_pvt.py @@ -0,0 +1,671 @@ +# coding=utf-8 +# Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, +# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
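+# A minimal usage sketch for the image processors defined above (both the slow and
+# the fast class share these defaults; the zero-valued image is just a stand-in):
+#
+#     import numpy as np
+#     from PIL import Image
+#     from transformers import PvtImageProcessor
+#
+#     processor = PvtImageProcessor()  # 224x224 resize, 1/255 rescale, ImageNet mean/std
+#     image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
+#     inputs = processor(images=image, return_tensors="pt")
+#     inputs["pixel_values"].shape  # torch.Size([1, 3, 224, 224])
+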
+"""PyTorch PVT model.""" + +import collections +import math +from typing import Iterable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_pvt import PvtConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PvtConfig" + +_CHECKPOINT_FOR_DOC = "Zetatech/pvt-tiny-224" +_EXPECTED_OUTPUT_SHAPE = [1, 50, 512] + +_IMAGE_CLASS_CHECKPOINT = "Zetatech/pvt-tiny-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Pvt +class PvtDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class PvtPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__( + self, + config: PvtConfig, + image_size: Union[int, Iterable[int]], + patch_size: Union[int, Iterable[int]], + stride: int, + num_channels: int, + hidden_size: int, + cls_token: bool = False, + ): + super().__init__() + self.config = config + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1 if cls_token else num_patches, hidden_size) + ) + self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) if cls_token else None + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=stride, stride=patch_size) + self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + num_patches = height * width + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == self.config.image_size * self.config.image_size: + return self.position_embeddings + embeddings = embeddings.reshape(1, height, width, -1).permute(0, 3, 1, 2) + interpolated_embeddings = F.interpolate(embeddings, size=(height, width), mode="bilinear") + interpolated_embeddings = interpolated_embeddings.reshape(1, -1, height * width).permute(0, 2, 1) + return interpolated_embeddings + + def forward(self, pixel_values: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+        )
+        patch_embed = self.projection(pixel_values)
+        *_, height, width = patch_embed.shape
+        patch_embed = patch_embed.flatten(2).transpose(1, 2)
+        embeddings = self.layer_norm(patch_embed)
+        if self.cls_token is not None:
+            cls_token = self.cls_token.expand(batch_size, -1, -1)
+            embeddings = torch.cat((cls_token, embeddings), dim=1)
+            position_embeddings = self.interpolate_pos_encoding(self.position_embeddings[:, 1:], height, width)
+            position_embeddings = torch.cat((self.position_embeddings[:, :1], position_embeddings), dim=1)
+        else:
+            position_embeddings = self.interpolate_pos_encoding(self.position_embeddings, height, width)
+        embeddings = self.dropout(embeddings + position_embeddings)
+
+        return embeddings, height, width
+
+
+class PvtSelfOutput(nn.Module):
+    def __init__(self, config: PvtConfig, hidden_size: int):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class PvtEfficientSelfAttention(nn.Module):
+    """Efficient self-attention mechanism with sequence reduction, as introduced in the [PvT paper](https://arxiv.org/abs/2102.12122)."""
+
+    def __init__(
+        self, config: PvtConfig, hidden_size: int, num_attention_heads: int, sequences_reduction_ratio: float
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({self.num_attention_heads})"
+            )
+
+        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+        self.sequences_reduction_ratio = sequences_reduction_ratio
+        if sequences_reduction_ratio > 1:
+            self.sequence_reduction = nn.Conv2d(
+                hidden_size, hidden_size, kernel_size=sequences_reduction_ratio, stride=sequences_reduction_ratio
+            )
+            self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+    def transpose_for_scores(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        hidden_states = hidden_states.view(new_shape)
+        return hidden_states.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        height: int,
+        width: int,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor]:
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        if self.sequences_reduction_ratio > 1:
+            batch_size, seq_len, num_channels = hidden_states.shape
+            # Reshape to (batch_size, num_channels, height, width)
+            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+            # Apply sequence reduction
+            hidden_states = self.sequence_reduction(hidden_states)
+            # Reshape back to (batch_size, seq_len, num_channels)
+            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
+            hidden_states = self.layer_norm(hidden_states)
+
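+        # Shape sketch (illustrative numbers, not taken from any checkpoint): with
+        # height = width = 56, hidden_size = 64 and sequences_reduction_ratio = 8,
+        # the strided convolution above shrinks the key/value sequence from
+        # 56 * 56 = 3136 tokens to 7 * 7 = 49 tokens. The attention matrix computed
+        # below is therefore 3136 x 49 per head instead of 3136 x 3136, which is
+        # what makes this self-attention "efficient" at high resolutions.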
key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class PvtAttention(nn.Module): + def __init__( + self, config: PvtConfig, hidden_size: int, num_attention_heads: int, sequences_reduction_ratio: float + ): + super().__init__() + self.self = PvtEfficientSelfAttention( + config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequences_reduction_ratio=sequences_reduction_ratio, + ) + self.output = PvtSelfOutput(config, hidden_size=hidden_size) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, height, width, output_attentions) + + attention_output = self.output(self_outputs[0]) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class PvtFFN(nn.Module): + def __init__( + self, + config: PvtConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + ): + super().__init__() + out_features = out_features if out_features is not None else in_features + self.dense1 = nn.Linear(in_features, hidden_features) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(hidden_features, out_features) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = 
self.dropout(hidden_states) + return hidden_states + + +class PvtLayer(nn.Module): + def __init__( + self, + config: PvtConfig, + hidden_size: int, + num_attention_heads: int, + drop_path: float, + sequences_reduction_ratio: float, + mlp_ratio: float, + ): + super().__init__() + self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.attention = PvtAttention( + config=config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequences_reduction_ratio=sequences_reduction_ratio, + ) + self.drop_path = PvtDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + mlp_hidden_size = int(hidden_size * mlp_ratio) + self.mlp = PvtFFN(config=config, in_features=hidden_size, hidden_features=mlp_hidden_size) + + def forward(self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False): + self_attention_outputs = self.attention( + hidden_states=self.layer_norm_1(hidden_states), + height=height, + width=width, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] + + attention_output = self.drop_path(attention_output) + hidden_states = attention_output + hidden_states + + mlp_output = self.mlp(self.layer_norm_2(hidden_states)) + + mlp_output = self.drop_path(mlp_output) + layer_output = hidden_states + mlp_output + + outputs = (layer_output,) + outputs + + return outputs + + +class PvtEncoder(nn.Module): + def __init__(self, config: PvtConfig): + super().__init__() + self.config = config + + # stochastic depth decay rule + drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist() + + # patch embeddings + embeddings = [] + + for i in range(config.num_encoder_blocks): + embeddings.append( + PvtPatchEmbeddings( + config=config, + image_size=config.image_size if i == 0 else self.config.image_size // (2 ** (i + 1)), + patch_size=config.patch_sizes[i], + stride=config.strides[i], + num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], + hidden_size=config.hidden_sizes[i], + cls_token=i == config.num_encoder_blocks - 1, + ) + ) + self.patch_embeddings = nn.ModuleList(embeddings) + + # Transformer blocks + blocks = [] + cur = 0 + for i in range(config.num_encoder_blocks): + # each block consists of layers + layers = [] + if i != 0: + cur += config.depths[i - 1] + for j in range(config.depths[i]): + layers.append( + PvtLayer( + config=config, + hidden_size=config.hidden_sizes[i], + num_attention_heads=config.num_attention_heads[i], + drop_path=drop_path_decays[cur + j], + sequences_reduction_ratio=config.sequence_reduction_ratios[i], + mlp_ratio=config.mlp_ratios[i], + ) + ) + blocks.append(nn.ModuleList(layers)) + + self.block = nn.ModuleList(blocks) + + # Layer norms + self.layer_norm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps) + + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + batch_size = pixel_values.shape[0] + num_blocks = len(self.block) + hidden_states = pixel_values + for idx, (embedding_layer, block_layer) in enumerate(zip(self.patch_embeddings, self.block)): + # first, obtain patch embeddings + 
hidden_states, height, width = embedding_layer(hidden_states) + # second, send embeddings through blocks + for block in block_layer: + layer_outputs = block(hidden_states, height, width, output_attentions) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if idx != num_blocks - 1: + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class PvtPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PvtConfig + base_model_prefix = "pvt" + main_input_name = "pixel_values" + _no_split_modules = [] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, PvtPatchEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data, + mean=0.0, + std=self.config.initializer_range, + ) + if module.cls_token is not None: + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data, + mean=0.0, + std=self.config.initializer_range, + ) + + +PVT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`~PvtConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PVT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`PvtImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare Pvt encoder outputting raw hidden-states without any specific head on top.", + PVT_START_DOCSTRING, +) +class PvtModel(PvtPreTrainedModel): + def __init__(self, config: PvtConfig): + super().__init__(config) + self.config = config + + # hierarchical Transformer encoder + self.encoder = PvtEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(PVT_INPUTS_DOCSTRING.format("(batch_size, channels, height, width)")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + Pvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + PVT_START_DOCSTRING, +) +class PvtForImageClassification(PvtPreTrainedModel): + def __init__(self, config: PvtConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.pvt = PvtModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PVT_INPUTS_DOCSTRING.format("(batch_size, channels, height, width)")) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor], + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.pvt( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["PvtForImageClassification", "PvtModel", "PvtPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/pvt_v2/__init__.py b/docs/transformers/src/transformers/models/pvt_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3cb83f130dd42eb08943d5d4e08ed9321e6f076 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt_v2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
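+# Note on the lazy-import pattern used below: this module replaces itself in
+# sys.modules with a _LazyModule, so the submodules are only imported when one
+# of their attributes is first accessed. A rough usage sketch (hypothetical
+# session; nothing is materialized until the first line runs):
+#
+#     from transformers.models.pvt_v2 import PvtV2Config  # triggers the real import
+#     config = PvtV2Config(depths=[2, 2, 2, 2])           # ordinary eager object from here on
+#
+# This keeps `import transformers` cheap even as the number of models grows.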
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_pvt_v2 import * + from .modeling_pvt_v2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/pvt_v2/configuration_pvt_v2.py b/docs/transformers/src/transformers/models/pvt_v2/configuration_pvt_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9aef0e760f3c5d47bbc8b9b67274b33ce5b19b1e --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt_v2/configuration_pvt_v2.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2024 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, +# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pvt V2 model configuration""" + +from typing import Callable, List, Tuple, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class PvtV2Config(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PvtV2Model`]. It is used to instantiate a Pvt V2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Pvt V2 B0 + [OpenGVLab/pvt_v2_b0](https://huggingface.co/OpenGVLab/pvt_v2_b0) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`Union[int, Tuple[int, int]]`, *optional*, defaults to 224): + The input image size. Pass int value for square image, or tuple of (height, width). + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_encoder_blocks (`[int]`, *optional*, defaults to 4): + The number of encoder blocks (i.e. stages in the Mix Transformer encoder). + depths (`List[int]`, *optional*, defaults to `[2, 2, 2, 2]`): + The number of layers in each encoder block. + sr_ratios (`List[int]`, *optional*, defaults to `[8, 4, 2, 1]`): + Spatial reduction ratios in each encoder block. + hidden_sizes (`List[int]`, *optional*, defaults to `[32, 64, 160, 256]`): + Dimension of each of the encoder blocks. + patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`): + Patch size for overlapping patch embedding before each encoder block. + strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`): + Stride for overlapping patch embedding before each encoder block. 
+ num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + mlp_ratios (`List[int]`, *optional*, defaults to `[8, 8, 4, 4]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for stochastic depth, used in the blocks of the Transformer encoder. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + linear_attention (`bool`, *optional*, defaults to `False`): + Use linear attention complexity. If set to True, `sr_ratio` is ignored and average pooling is used for + dimensionality reduction in the attention layers rather than strided convolution. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. 
+ Example: + + ```python + >>> from transformers import PvtV2Model, PvtV2Config + + >>> # Initializing a pvt_v2_b0 style configuration + >>> configuration = PvtV2Config() + + >>> # Initializing a model from the OpenGVLab/pvt_v2_b0 style configuration + >>> model = PvtV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pvt_v2" + + def __init__( + self, + image_size: Union[int, Tuple[int, int]] = 224, + num_channels: int = 3, + num_encoder_blocks: int = 4, + depths: List[int] = [2, 2, 2, 2], + sr_ratios: List[int] = [8, 4, 2, 1], + hidden_sizes: List[int] = [32, 64, 160, 256], + patch_sizes: List[int] = [7, 3, 3, 3], + strides: List[int] = [4, 2, 2, 2], + num_attention_heads: List[int] = [1, 2, 5, 8], + mlp_ratios: List[int] = [8, 8, 4, 4], + hidden_act: Union[str, Callable] = "gelu", + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + initializer_range: float = 0.02, + drop_path_rate: float = 0.0, + layer_norm_eps: float = 1e-6, + qkv_bias: bool = True, + linear_attention: bool = False, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + image_size = (image_size, image_size) if isinstance(image_size, int) else image_size + + self.image_size = image_size + self.num_channels = num_channels + self.num_encoder_blocks = num_encoder_blocks + self.depths = depths + self.sr_ratios = sr_ratios + self.hidden_sizes = hidden_sizes + self.patch_sizes = patch_sizes + self.strides = strides + self.mlp_ratios = mlp_ratios + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.drop_path_rate = drop_path_rate + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.linear_attention = linear_attention + self.stage_names = [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["PvtV2Config"] diff --git a/docs/transformers/src/transformers/models/pvt_v2/convert_pvt_v2_to_pytorch.py b/docs/transformers/src/transformers/models/pvt_v2/convert_pvt_v2_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b315d540dab3af2a332ef5ccc080405ce5812e86 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt_v2/convert_pvt_v2_to_pytorch.py @@ -0,0 +1,294 @@ +# coding=utf-8 +# Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, +# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
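+# Typical invocation of this script (paths are illustrative; the .pth checkpoint
+# must first be downloaded from the original repository at
+# https://github.com/whai362/PVT):
+#
+#     python convert_pvt_v2_to_pytorch.py \
+#         --pvt_v2_size b0 \
+#         --pvt_v2_checkpoint pvt_v2_b0.pth \
+#         --pytorch_dump_folder_path ./pvt_v2_b0_converted \
+#         --verify-imagenet-weights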
+"""Convert PvtV2 checkpoints from the original library.""" + +import argparse +from pathlib import Path + +import requests +import torch +from PIL import Image + +from transformers import PvtImageProcessor, PvtV2Config, PvtV2ForImageClassification +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + for i in range(config.num_encoder_blocks): + # Remane embedings' paramters + rename_keys.append( + (f"patch_embed{i + 1}.proj.weight", f"pvt_v2.encoder.layers.{i}.patch_embedding.proj.weight") + ) + rename_keys.append((f"patch_embed{i + 1}.proj.bias", f"pvt_v2.encoder.layers.{i}.patch_embedding.proj.bias")) + rename_keys.append( + (f"patch_embed{i + 1}.norm.weight", f"pvt_v2.encoder.layers.{i}.patch_embedding.layer_norm.weight") + ) + rename_keys.append( + (f"patch_embed{i + 1}.norm.bias", f"pvt_v2.encoder.layers.{i}.patch_embedding.layer_norm.bias") + ) + rename_keys.append((f"norm{i + 1}.weight", f"pvt_v2.encoder.layers.{i}.layer_norm.weight")) + rename_keys.append((f"norm{i + 1}.bias", f"pvt_v2.encoder.layers.{i}.layer_norm.bias")) + + for j in range(config.depths[i]): + # Rename blocks' parameters + rename_keys.append( + (f"block{i + 1}.{j}.attn.q.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.query.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.q.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.query.bias") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.kv.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.kv.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.kv.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.kv.bias") + ) + + if config.linear_attention or config.sr_ratios[i] > 1: + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.norm.weight", + f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.norm.bias", + f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.layer_norm.bias", + ) + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.sr.weight", + f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.spatial_reduction.weight", + ) + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.attn.sr.bias", + f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.spatial_reduction.bias", + ) + ) + + rename_keys.append( + (f"block{i + 1}.{j}.attn.proj.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.proj.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.attn.proj.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.proj.bias") + ) + + rename_keys.append( + (f"block{i + 1}.{j}.norm1.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.layer_norm_1.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.norm1.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.layer_norm_1.bias") + ) + + rename_keys.append( + (f"block{i + 1}.{j}.norm2.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.layer_norm_2.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.norm2.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.layer_norm_2.bias") + ) + + rename_keys.append( + (f"block{i + 1}.{j}.mlp.fc1.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.mlp.dense1.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.mlp.fc1.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.mlp.dense1.bias") + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.mlp.dwconv.dwconv.weight", + 
f"pvt_v2.encoder.layers.{i}.blocks.{j}.mlp.dwconv.dwconv.weight", + ) + ) + rename_keys.append( + ( + f"block{i + 1}.{j}.mlp.dwconv.dwconv.bias", + f"pvt_v2.encoder.layers.{i}.blocks.{j}.mlp.dwconv.dwconv.bias", + ) + ) + rename_keys.append( + (f"block{i + 1}.{j}.mlp.fc2.weight", f"pvt_v2.encoder.layers.{i}.blocks.{j}.mlp.dense2.weight") + ) + rename_keys.append( + (f"block{i + 1}.{j}.mlp.fc2.bias", f"pvt_v2.encoder.layers.{i}.blocks.{j}.mlp.dense2.bias") + ) + + rename_keys.extend( + [ + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_k_v(state_dict, config): + # for each of the encoder blocks: + for i in range(config.num_encoder_blocks): + for j in range(config.depths[i]): + # read in weights + bias of keys and values (which is a single matrix in the original implementation) + kv_weight = state_dict.pop(f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.kv.weight") + kv_bias = state_dict.pop(f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.kv.bias") + # next, add keys and values (in that order) to the state dict + state_dict[f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.key.weight"] = kv_weight[ + : config.hidden_sizes[i], : + ] + state_dict[f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.key.bias"] = kv_bias[: config.hidden_sizes[i]] + + state_dict[f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.value.weight"] = kv_weight[ + config.hidden_sizes[i] :, : + ] + state_dict[f"pvt_v2.encoder.layers.{i}.blocks.{j}.attention.value.bias"] = kv_bias[ + config.hidden_sizes[i] : + ] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_pvt_v2_checkpoint(pvt_v2_size, pvt_v2_checkpoint, pytorch_dump_folder_path, verify_imagenet_weights=False): + """ + Copy/paste/tweak model's weights to our PVT structure. 
+ """ + + # define default PvtV2 configuration + if pvt_v2_size == "b0": + config_path = "OpenGVLab/pvt_v2_b0" + elif pvt_v2_size == "b1": + config_path = "OpenGVLab/pvt_v2_b1" + elif pvt_v2_size == "b2": + config_path = "OpenGVLab/pvt_v2_b2" + elif pvt_v2_size == "b2-linear": + config_path = "OpenGVLab/pvt_v2_b2_linear" + elif pvt_v2_size == "b3": + config_path = "OpenGVLab/pvt_v2_b3" + elif pvt_v2_size == "b4": + config_path = "OpenGVLab/pvt_v2_b4" + elif pvt_v2_size == "b5": + config_path = "OpenGVLab/pvt_v2_b5" + else: + raise ValueError( + f"Available model sizes: 'b0', 'b1', 'b2', 'b2-linear', 'b3', 'b4', 'b5', but '{pvt_v2_size}' was given" + ) + config = PvtV2Config.from_pretrained(config_path) + # load original model from https://github.com/whai362/PVT + state_dict = torch.load(pvt_v2_checkpoint, map_location="cpu", weights_only=True) + + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_k_v(state_dict, config) + + # load HuggingFace model + model = PvtV2ForImageClassification(config).eval() + model.load_state_dict(state_dict) + image_processor = PvtImageProcessor(size=config.image_size) + + if verify_imagenet_weights: + # Check outputs on an image, prepared by PvtImageProcessor + print("Verifying conversion of pretrained ImageNet weights...") + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + logits = outputs.logits.detach().cpu() + + if pvt_v2_size == "b0": + expected_slice_logits = torch.tensor([-1.1939, -1.4547, -0.1076]) + elif pvt_v2_size == "b1": + expected_slice_logits = torch.tensor([-0.4716, -0.7335, -0.4600]) + elif pvt_v2_size == "b2": + expected_slice_logits = torch.tensor([0.0795, -0.3170, 0.2247]) + elif pvt_v2_size == "b2-linear": + expected_slice_logits = torch.tensor([0.0968, 0.3937, -0.4252]) + elif pvt_v2_size == "b3": + expected_slice_logits = torch.tensor([-0.4595, -0.2870, 0.0940]) + elif pvt_v2_size == "b4": + expected_slice_logits = torch.tensor([-0.1769, -0.1747, -0.0143]) + elif pvt_v2_size == "b5": + expected_slice_logits = torch.tensor([-0.2943, -0.1008, 0.6812]) + else: + raise ValueError( + f"Available model sizes: 'b0', 'b1', 'b2', 'b2-linear', 'b3', 'b4', 'b5', but " + f"'{pvt_v2_size}' was given" + ) + + assert torch.allclose(logits[0, :3], expected_slice_logits, atol=1e-4), ( + "ImageNet weights not converted successfully." + ) + + print("ImageNet weights verified, conversion successful.") + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model pytorch_model.bin to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--pvt_v2_size", + default="b0", + type=str, + help="Size of the PVTv2 pretrained model you'd like to convert.", + ) + parser.add_argument( + "--pvt_v2_checkpoint", + default="pvt_v2_b0.pth", + type=str, + help="Checkpoint of the PVTv2 pretrained model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." 
+ ) + parser.add_argument( + "--verify-imagenet-weights", + action="store_true", + default=False, + help="Verifies the correct conversion of author-published pretrained ImageNet weights.", + ) + + args = parser.parse_args() + convert_pvt_v2_checkpoint( + pvt_v2_size=args.pvt_v2_size, + pvt_v2_checkpoint=args.pvt_v2_checkpoint, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + verify_imagenet_weights=args.verify_imagenet_weights, + ) diff --git a/docs/transformers/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/docs/transformers/src/transformers/models/pvt_v2/modeling_pvt_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..517d2edbe32381f7d811d15dab052ed16d9514c1 --- /dev/null +++ b/docs/transformers/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -0,0 +1,703 @@ +# coding=utf-8 +# Copyright 2024 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, +# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PVTv2 model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutput, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_pvt_v2 import PvtV2Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PvtV2Config" + +_CHECKPOINT_FOR_DOC = "OpenGVLab/pvt_v2_b0" +_EXPECTED_OUTPUT_SHAPE = [1, 256, 7, 7] + +_IMAGE_CLASS_CHECKPOINT = "OpenGVLab/pvt_v2_b0" +_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281" # ImageNet ID for "tabby, tabby cat" + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Pvt +class PvtV2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class PvtV2OverlapPatchEmbeddings(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, config: PvtV2Config, layer_idx: int): + super().__init__() + patch_size = config.patch_sizes[layer_idx] + patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + stride = config.strides[layer_idx] + num_channels = config.num_channels if layer_idx == 0 else config.hidden_sizes[layer_idx - 1] + hidden_size = config.hidden_sizes[layer_idx] + self.patch_size = patch_size + self.proj = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + def forward(self, pixel_values): + embeddings = self.proj(pixel_values) + _, _, height, width = embeddings.shape + embeddings = embeddings.flatten(2).transpose(1, 2) + embeddings = self.layer_norm(embeddings) + return embeddings, height, width + + +class PvtV2DepthWiseConv(nn.Module): + """ + Depth-wise (DW) convolution to infuse positional information using zero-padding. Depth-wise convolutions + have an equal number of groups to the number of input channels, meaning one filter per input channel. This + reduces the overall parameters and compute costs since the key purpose of this layer is position encoding. 
+ """ + + def __init__(self, config: PvtV2Config, dim: int = 768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, hidden_states, height, width): + batch_size, seq_len, num_channels = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width) + hidden_states = self.dwconv(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + return hidden_states + + +class PvtV2SelfAttention(nn.Module): + """Efficient self-attention mechanism.""" + + def __init__(self, config: PvtV2Config, hidden_size: int, num_attention_heads: int, spatial_reduction_ratio: int): + super().__init__() + self.linear_attention = config.linear_attention + self.pruned_heads = set() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " + f"heads ({self.num_attention_heads})" + ) + + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_probs_dropout_prob) + self.proj = nn.Linear(self.hidden_size, self.hidden_size) + self.proj_drop = nn.Dropout(config.hidden_dropout_prob) + + self.spatial_reduction_ratio = spatial_reduction_ratio + if self.linear_attention: + self.pool = nn.AdaptiveAvgPool2d(7) + self.spatial_reduction = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, stride=1) + self.layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + self.act = nn.GELU() + elif spatial_reduction_ratio > 1: + self.spatial_reduction = nn.Conv2d( + self.hidden_size, self.hidden_size, kernel_size=spatial_reduction_ratio, stride=spatial_reduction_ratio + ) + self.layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + def transpose_for_scores(self, hidden_states) -> torch.Tensor: + new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + hidden_states = hidden_states.view(new_shape) + return hidden_states.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + height: int, + width: int, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor]: + batch_size, seq_len, num_channels = hidden_states.shape + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + if self.linear_attention: + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + hidden_states = ( + self.spatial_reduction(self.pool(hidden_states)).reshape(batch_size, num_channels, -1).permute(0, 2, 1) + ) + hidden_states = self.act(self.layer_norm(hidden_states)) + elif self.spatial_reduction_ratio > 1: + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + hidden_states = ( + self.spatial_reduction(hidden_states).reshape(batch_size, num_channels, -1).permute(0, 2, 1) + ) + hidden_states = self.layer_norm(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) 
+ + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attn_drop(attention_probs) + context_layer = (attention_probs @ value_layer).transpose(1, 2).reshape(batch_size, seq_len, num_channels) + context_layer = self.proj(context_layer) + context_layer = self.proj_drop(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.query = prune_linear_layer(self.query, index) + self.key = prune_linear_layer(self.key, index) + self.value = prune_linear_layer(self.value, index) + self.proj = prune_linear_layer(self.proj, index, dim=1) + + # Update hyper params and store pruned heads + self.num_attention_heads = self.num_attention_heads - len(heads) + self.all_head_size = self.attention_head_size * self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + +class PvtV2ConvFeedForwardNetwork(nn.Module): + def __init__( + self, + config: PvtV2Config, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + ): + super().__init__() + out_features = out_features if out_features is not None else in_features + self.dense1 = nn.Linear(in_features, hidden_features) + self.dwconv = PvtV2DepthWiseConv(config, hidden_features) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(hidden_features, out_features) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.relu = nn.ReLU() if config.linear_attention else nn.Identity() + + def forward(self, hidden_states: torch.Tensor, height, width) -> torch.Tensor: + hidden_states = self.dense1(hidden_states) + hidden_states = self.relu(hidden_states) + hidden_states = self.dwconv(hidden_states, height, width) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class PvtV2BlockLayer(nn.Module): + def __init__(self, config: PvtV2Config, layer_idx: int, drop_path: float = 0.0): + super().__init__() + hidden_size: int = config.hidden_sizes[layer_idx] + num_attention_heads: int = config.num_attention_heads[layer_idx] + spatial_reduction_ratio: int = config.sr_ratios[layer_idx] + mlp_ratio: float = config.mlp_ratios[layer_idx] + self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.attention = PvtV2SelfAttention( + config=config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + spatial_reduction_ratio=spatial_reduction_ratio, + ) + self.drop_path = PvtV2DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + 
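+        # Worked example using the default (b0-style) config values documented
+        # above: the second stage (layer_idx = 1) has hidden_sizes[1] = 64 and
+        # mlp_ratios[1] = 8, so the feed-forward network below expands to
+        # int(64 * 8) = 512 hidden features before projecting back down to 64.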
mlp_hidden_size = int(hidden_size * mlp_ratio) + self.mlp = PvtV2ConvFeedForwardNetwork(config=config, in_features=hidden_size, hidden_features=mlp_hidden_size) + + def forward(self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False): + self_attention_outputs = self.attention( + hidden_states=self.layer_norm_1(hidden_states), + height=height, + width=width, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] + + attention_output = self.drop_path(attention_output) + hidden_states = attention_output + hidden_states + + mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width) + + mlp_output = self.drop_path(mlp_output) + layer_output = hidden_states + mlp_output + + outputs = (layer_output,) + outputs + + return outputs + + +class PvtV2EncoderLayer(nn.Module): + def __init__(self, config: PvtV2Config, layer_idx: int): + super().__init__() + self.patch_embedding = PvtV2OverlapPatchEmbeddings( + config=config, + layer_idx=layer_idx, + ) + # Transformer block + # stochastic depth decay rule + drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist() + block_layers = [] + for block_idx in range(config.depths[layer_idx]): + block_layers.append( + PvtV2BlockLayer( + config=config, + layer_idx=layer_idx, + drop_path=drop_path_decays[sum(config.depths[:layer_idx]) + block_idx], + ) + ) + self.blocks = nn.ModuleList(block_layers) + + # Layer norm + self.layer_norm = nn.LayerNorm(config.hidden_sizes[layer_idx], eps=config.layer_norm_eps) + + def forward(self, hidden_states, output_attentions): + all_self_attentions = () if output_attentions else None + # first, obtain patch embeddings + hidden_states, height, width = self.patch_embedding(hidden_states) + # second, send embeddings through blocks + for block in self.blocks: + layer_outputs = block(hidden_states, height, width, output_attentions) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions += (layer_outputs[1],) + # third, apply layer norm + hidden_states = self.layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (all_self_attentions,) + + return outputs, height, width + + +class PvtV2Encoder(nn.Module): + def __init__(self, config: PvtV2Config): + super().__init__() + self.config = config + self.gradient_checkpointing = False + + # encoder layers + self.layers = nn.ModuleList([PvtV2EncoderLayer(config, i) for i in range(config.num_encoder_blocks)]) + + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + batch_size = pixel_values.shape[0] + hidden_states = pixel_values + for idx, layer in enumerate(self.layers): + if self.gradient_checkpointing and self.training: + layer_output = self._gradient_checkpointing_func(layer.__call__, hidden_states, output_attentions) + else: + layer_output = layer(hidden_states, output_attentions) + outputs, height, width = layer_output + hidden_states = outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + # reshape back to (batch_size, num_channels, height, width) + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 
3, 1, 2).contiguous() + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class PvtV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PvtV2Config + base_model_prefix = "pvt_v2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv2d): + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + module.bias.data.zero_() + + +PVT_V2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`~PvtV2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PVT_V2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`PvtImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Pvt-v2 encoder outputting raw hidden-states without any specific head on top.", + PVT_V2_START_DOCSTRING, +) +class PvtV2Model(PvtV2PreTrainedModel): + def __init__(self, config: PvtV2Config): + super().__init__(config) + self.config = config + + # hierarchical Transformer encoder + self.encoder = PvtV2Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(PVT_V2_INPUTS_DOCSTRING.format("(batch_size, channels, height, width)")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + Pvt-v2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + PVT_V2_START_DOCSTRING, +) +class PvtV2ForImageClassification(PvtV2PreTrainedModel): + def __init__(self, config: PvtV2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.pvt_v2 = PvtV2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PVT_V2_INPUTS_DOCSTRING.format("(batch_size, channels, height, width)")) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor], + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
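As a concrete illustration of the three loss branches selected via `config.problem_type` (a minimal, self-contained sketch with made-up shapes; it does not run the model itself):

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

logits = torch.randn(2, 5)  # (batch_size, num_labels)

# integer labels with num_labels > 1 -> "single_label_classification" (cross-entropy)
ce_loss = CrossEntropyLoss()(logits.view(-1, 5), torch.tensor([1, 4]).view(-1))

# float multi-hot labels -> "multi_label_classification" (BCE with logits)
bce_loss = BCEWithLogitsLoss()(logits, torch.randint(0, 2, (2, 5)).float())

# num_labels == 1 -> "regression" (mean-squared error on squeezed logits)
mse_loss = MSELoss()(torch.randn(2, 1).squeeze(), torch.randn(2).squeeze())
```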
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.pvt_v2( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # convert last hidden states to (batch_size, height*width, hidden_size) + batch_size = sequence_output.shape[0] + # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + sequence_output = sequence_output.permute(0, 2, 3, 1) + sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1]) + + # global average pooling + sequence_output = sequence_output.mean(dim=1) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + PVTv2 backbone, to be used with frameworks like DETR and MaskFormer. + """, + PVT_V2_START_DOCSTRING, +) +class PvtV2Backbone(PvtV2Model, BackboneMixin): + def __init__(self, config: PvtV2Config): + super().__init__(config) + super()._init_backbone(config) + self.num_features = config.hidden_sizes + + @add_start_docstrings_to_model_forward(PVT_V2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("OpenGVLab/pvt_v2_b0") + >>> model = AutoBackbone.from_pretrained( + ... "OpenGVLab/pvt_v2_b0", out_features=["stage1", "stage2", "stage3", "stage4"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 256, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = ["PvtV2ForImageClassification", "PvtV2Model", "PvtV2PreTrainedModel", "PvtV2Backbone"] diff --git a/docs/transformers/src/transformers/models/qwen2/__init__.py b/docs/transformers/src/transformers/models/qwen2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e447a94c54c86d07eee6c534a6860a2a5a6d0c3e --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen2 import * + from .modeling_qwen2 import * + from .tokenization_qwen2 import * + from .tokenization_qwen2_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/qwen2/configuration_qwen2.py b/docs/transformers/src/transformers/models/qwen2/configuration_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..fde5bd502b549c0624d7b675d35909a7cf06dc94 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2/configuration_qwen2.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you update this value + accordingly.
+ Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to the value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (shorter than + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2. + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (longer than + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2. + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE. + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of initial layers that use full attention; when `use_sliding_window=True`, layers from index + `max_window_layers` onwards use SWA (Sliding Window Attention). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities.
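For instance, enabling sliding-window attention together with YaRN RoPE scaling might look like the following (the values are illustrative, not recommended settings):

```python
>>> from transformers import Qwen2Config

>>> config = Qwen2Config(
...     use_sliding_window=True,
...     sliding_window=4096,
...     max_window_layers=28,
...     rope_scaling={"rope_type": "yarn", "factor": 4.0},
... )
```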
+ + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Qwen2Config"] diff --git a/docs/transformers/src/transformers/models/qwen2/modeling_qwen2.py b/docs/transformers/src/transformers/models/qwen2/modeling_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..46fb6326ada90bd0dac50e0bc57d505cb8f92520 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2/modeling_qwen2.py @@ -0,0 +1,1126 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_qwen2 import Qwen2Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B" +_CONFIG_FOR_DOC = "Qwen2Config" + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
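Example (a shape-only sketch; real `cos`/`sin` come from the model's rotary embedding module, and random tensors are used here just to show the broadcasting):

```python
>>> import torch
>>> q = torch.randn(1, 8, 10, 64)  # (batch, num_attention_heads, seq_len, head_dim)
>>> k = torch.randn(1, 2, 10, 64)  # fewer key/value heads under GQA
>>> cos, sin = torch.randn(1, 10, 64), torch.randn(1, 10, 64)
>>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
>>> (q_embed.shape, k_embed.shape)
(torch.Size([1, 8, 10, 64]), torch.Size([1, 2, 10, 64]))
```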
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; 
cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
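A minimal loading sketch (added for illustration; `Qwen/Qwen2-7B` is one public Qwen2 checkpoint, and any other works the same way):

```python
>>> from transformers import AutoTokenizer, Qwen2Model

>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
>>> model = Qwen2Model.from_pretrained("Qwen/Qwen2-7B")
```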
+""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
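# Concretely: with plain SDPA, a dynamic (or absent) cache, and no request for attention
# weights, the checks below may return `None`, which lets scaled_dot_product_attention run
# with `is_causal=True` and dispatch to its fused causal kernels instead of consuming an
# explicit 4D float mask.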
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Qwen2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.get_text_config().sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.get_text_config().sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
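The static mask helper defined just above can be exercised on its own; a minimal sketch for a three-token prefill with no padding and no cache (assuming an installed `transformers` version that exposes this method; shapes and values are illustrative):

```python
import torch

from transformers import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

causal_4d = Qwen2Model._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.tensor([[1, 1, 1]]),  # one sequence of length 3, no padding
    sequence_length=3,
    target_length=3,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=torch.arange(3),
    batch_size=1,
    config=Qwen2Config(),
    past_key_values=None,
)
print(causal_4d.shape)  # torch.Size([1, 1, 3, 3]); future positions hold the dtype minimum
```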
+ + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-7B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
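The last-token pooling described in the class docstring reduces to a small piece of index arithmetic; a self-contained sketch (illustrative token ids, `pad_token_id=0` assumed):

```python
>>> import torch

>>> input_ids = torch.tensor([[5, 6, 7, 0, 0],
...                           [8, 9, 0, 0, 0]])  # right-padded with pad_token_id=0
>>> non_pad_mask = (input_ids != 0).int()
>>> token_indices = torch.arange(input_ids.shape[-1])
>>> (token_indices * non_pad_mask).argmax(-1)  # index of the last real token per row
tensor([2, 1])
```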
+ """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+    """,
+    QWEN2_START_DOCSTRING,
+)
+class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = Qwen2Model(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TokenClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
+        """
+
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+        sequence_output = outputs.last_hidden_state
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config)
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    QWEN2_START_DOCSTRING,
+)
+class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
+    base_model_prefix = "transformer"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = Qwen2Model(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.transformer.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.transformer.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs,
+    ) -> QuestionAnsweringModelOutput:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
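+
+        Example (illustrative sketch; `"Qwen/Qwen2-0.5B"` is a placeholder checkpoint and `qa_outputs` is newly
+        initialized, so the extracted span is only meaningful after fine-tuning on a QA dataset):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, Qwen2ForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")  # placeholder checkpoint
+        >>> model = Qwen2ForQuestionAnswering.from_pretrained("Qwen/Qwen2-0.5B")
+
+        >>> question, context = "Who wrote the novel?", "The novel was written by Jane Austen."
+        >>> inputs = tokenizer(question, context, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> # pick the most likely start/end positions and decode the span between them
+        >>> start = outputs.start_logits.argmax()
+        >>> end = outputs.end_logits.argmax()
+        >>> answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])
+        ```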
+ """ + + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/docs/transformers/src/transformers/models/qwen2/modular_qwen2.py b/docs/transformers/src/transformers/models/qwen2/modular_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..afb3d1cf35e5e4cb9377bd1c9633f905bf63274c --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2/modular_qwen2.py @@ -0,0 +1,149 @@ +from typing import Callable, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaMLP, + LlamaPreTrainedModel, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..mistral.modeling_mistral import MistralModel +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + + +class Qwen2MLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class Qwen2Attention(LlamaAttention): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, 
self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + +class Qwen2PreTrainedModel(LlamaPreTrainedModel): + pass + + +class Qwen2Model(MistralModel): + pass + + +class Qwen2ForCausalLM(LlamaForCausalLM): + pass + + +class Qwen2ForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen2ForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/docs/transformers/src/transformers/models/qwen2/tokenization_qwen2.py b/docs/transformers/src/transformers/models/qwen2/tokenization_qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..c388789b728ff06172b5297f8238f2519d2eabca --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2/tokenization_qwen2.py @@ -0,0 +1,342 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Qwen2."""
+
+import json
+import os
+import unicodedata
+from functools import lru_cache
+from typing import Optional, Tuple
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+
+MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
+
+PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+@lru_cache()
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
+    characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Qwen2Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import Qwen2Tokenizer
+
+    >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
+    >>> tokenizer("Hello world")["input_ids"]
+    [9707, 1879]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [21927, 1879]
+    ```
+    This is expected.
+
+    You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*):
+            The beginning of sequence token. Not applicable for this tokenizer.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should clean up the spaces that were added when splitting the input text during the
+            tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. The default behavior is
+            to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then
+            `tokenizer.tokenize("<|endoftext|>")` = `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then
+            `tokenizer.tokenize("<|endoftext|>")` will give `['<', '|', 'endo', 'ft', 'ext', '|', '>']`. This argument
+            is only supported for `slow` tokenizers for the moment.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        unk_token="<|endoftext|>",
+        bos_token=None,
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",
+        clean_up_tokenization_spaces=False,
+        split_special_tokens=False,
+        **kwargs,
+    ):
+        # Qwen vocab does not contain control tokens; added tokens need to be special
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            for i, line in enumerate(merges_handle):
+                line = line.strip()
+                if (i == 0 and line.startswith("#version:")) or not line:
+                    continue
+                bpe_merges.append(tuple(line.split()))
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        # NOTE: the cache can grow without bound and will get really large for long running processes
+        # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
+        # not a memory leak but appears as one.
+        # GPT2Tokenizer has the same problem, so let's be consistent.
+        self.cache = {}
+
+        self.pat = re.compile(PRETOKENIZE_REGEX)
+
+        if kwargs.get("add_prefix_space", False):
+            logger.warning_once(
+                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
+            )
+
+        super().__init__(
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            split_special_tokens=split_special_tokens,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.encoder)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
+
+    def decode(
+        self,
+        token_ids,
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = False,
+        spaces_between_special_tokens: bool = False,
+        **kwargs,
+    ) -> str:
+        # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
+        # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
+        return super().decode(
+            token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+
spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, **kwargs): + text = unicodedata.normalize("NFC", text) + return (text, kwargs) + + +__all__ = ["Qwen2Tokenizer"] diff --git a/docs/transformers/src/transformers/models/qwen2/tokenization_qwen2_fast.py b/docs/transformers/src/transformers/models/qwen2/tokenization_qwen2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..b7312755ef580f157c024f9cf2b79c0ce6603077 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2/tokenization_qwen2_fast.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Qwen2.""" + +from typing import Optional, Tuple + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_qwen2 import Qwen2Tokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "tokenizer_file": "tokenizer.json", +} + + +MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768} + + +class Qwen2TokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level + Byte-Pair-Encoding. 
+ + Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import Qwen2TokenizerFast + + >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer") + >>> tokenizer("Hello world")["input_ids"] + [9707, 1879] + + >>> tokenizer(" Hello world")["input_ids"] + [21927, 1879] + ``` + This is expected. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + merges_file (`str`, *optional*): + Path to the merges file. + tokenizer_file (`str`, *optional*): + Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that + contains everything needed to load the tokenizer. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. Not applicable to this tokenizer. + bos_token (`str`, *optional*): + The beginning of sequence token. Not applicable for this tokenizer. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding, for example when batching sequences of different lengths. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = Qwen2Tokenizer + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + **kwargs, + ): + # We need to at least pass vocab_file and merges_file to base class + # in case a slow tokenizer needs to be initialized; other can be + # configured through files. 
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token + + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(pad_token, str) + else pad_token + ) + + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + tokenizer_file=tokenizer_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + **kwargs, + ) + + # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["Qwen2TokenizerFast"] diff --git a/docs/transformers/src/transformers/models/qwen2_5_omni/__init__.py b/docs/transformers/src/transformers/models/qwen2_5_omni/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7ddae0da7e1b2c16b95e94e6678f9511cb716f --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_omni/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen2_5_omni import * + from .modeling_qwen2_5_omni import * + from .processing_qwen2_5_omni import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/docs/transformers/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1d27ea91f210827bda858d41185f3d44200cd4 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -0,0 +1,1063 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_omni.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniVisionEncoder`]. It is used to instantiate a
+    Qwen2.5-Omni vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Qwen2.5-Omni
+    architecture.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        depth (`int`, *optional*, defaults to 32):
+            Number of layers (depth) in the model.
+        hidden_size (`int`, *optional*, defaults to 3584):
+            The size of the hidden layers.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function used in the model.
+        intermediate_size (`int`, *optional*, defaults to 3420):
+            The size of the MLP (Multi-Layer Perceptron) hidden layer.
+        num_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer.
+        in_channels (`int`, *optional*, defaults to 3):
+            Number of input channels.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size of the patches extracted from the input.
+        spatial_merge_size (`int`, *optional*, defaults to 2):
+            The size used for merging spatial dimensions.
+        temporal_patch_size (`int`, *optional*, defaults to 2):
+            The size used for patches along the temporal dimension.
+        window_size (`int`, *optional*, defaults to 112):
+            The size of the window used for windowed attention blocks.
+        out_hidden_size (`int`, *optional*, defaults to 3584):
+            The output hidden size of the vision encoder.
+        fullatt_block_indexes (`List[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
+            Indexes of the blocks that use full attention instead of windowed attention.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoderConfig
+    >>> configuration = Qwen2_5OmniVisionEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights)
+    >>> model = Qwen2_5OmniVisionEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_vision_encoder"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=3584,
+        hidden_act="silu",
+        intermediate_size=3420,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        window_size=112,
+        out_hidden_size=3584,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.out_hidden_size = out_hidden_size
+        self.initializer_range = initializer_range
+
+
+class Qwen2_5OmniAudioEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniAudioEncoder`]. It is used to instantiate a
+    Qwen2.5-Omni-Thinker audio encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
+    architecture.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of mel features used per input feature. Should correspond to the value used in the
+            `Qwen2_5OmniProcessor` class.
+        encoder_layers (`int`, *optional*, defaults to 32):
+            Number of encoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 20):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 5120):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        d_model (`int`, *optional*, defaults to 1280):
+            Dimensionality of the layers.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        max_source_positions (`int`, *optional*, defaults to 1500):
+            The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
+        n_window (`int`, *optional*, defaults to 100):
+            The chunk size used for convolution and flash attention in the AudioEncoder.
+        output_dim (`int`, *optional*, defaults to 3584):
+            The output dimension of the AudioEncoder.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder
+
+    >>> # Initializing a Qwen2_5OmniAudioEncoderConfig
+    >>> configuration = Qwen2_5OmniAudioEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights)
+    >>> model = Qwen2_5OmniAudioEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_audio_encoder"
+
+    def __init__(
+        self,
+        num_mel_bins=128,
+        encoder_layers=32,
+        encoder_attention_heads=20,
+        encoder_ffn_dim=5120,
+        d_model=1280,
+        dropout=0,
+        attention_dropout=0,
+        activation_function="gelu",
+        activation_dropout=0,
+        scale_embedding=False,
+        initializer_range=0.02,
+        max_source_positions=1500,
+        n_window=100,
+        output_dim=3584,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_mel_bins = num_mel_bins
+        self.d_model = d_model
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_function = activation_function
+        self.activation_dropout = activation_dropout
+        self.num_hidden_layers = encoder_layers
+        self.initializer_range = initializer_range
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.max_source_positions = max_source_positions
+        self.n_window = n_window
+        self.output_dim = output_dim
+
+
+class Qwen2_5OmniTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate a
+    Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 152064):
+            Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen2VLModel`].
+        hidden_size (`int`, *optional*, defaults to 3584):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 18944):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 28):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 28):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention.
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 32768): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + Example: + + ```python + >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniVisionEncoder config + >>> vision_config = Qwen2_5OmniVisionEncoderConfig() + + >>> # Initializing a Qwen2.5OmniThinker configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config) + + >>> # Initializing a model from the Qwen-Omni style configuration + >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_text" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen25OmniText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=3584, + intermediate_size=18944, + num_hidden_layers=28, + num_attention_heads=28, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=32768, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = 
rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + if self.rope_scaling is None: + self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} + + +class Qwen2_5OmniThinkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_config (`dict`, *optional*): + The config dictionary of the audio backbone. + vision_config (`dict`, *optional*): + The config dictionary of the vision backbone. + text_config (`dict`, *optional*): + The config dictionary of the text backbone. + audio_token_index (`int`, *optional*, defaults to 151646): + The audio token index to encode the audio prompt. + image_token_index (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 151656): + The video token index to encode the video prompt. + position_id_per_seconds (`int`, *optional*, defaults to 25): + The increment of position id per second. + seconds_per_chunk (`int`, *optional*, defaults to 2): + The duration in seconds of the chunk of audio and video data. + audio_start_token_id (`int`, *optional*, defaults to 151647): + The audio start token index to encode the audio prompt. + audio_end_token_id (`int`, *optional*, defaults to 151648): + The audio end token index to encode the audio prompt. + user_token_id (`int, *optional*, defaults to 872): + The user token index to encode the user token. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + + ```python + >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniVisionEncoder config + >>> vision_config = Qwen2_5OmniVisionEncoderConfig() + + >>> # Initializing a Qwen2_5OmniTextConfig config + >>> text_config = Qwen2_5OmniTextConfig() + + >>> # Initializing a Qwen2.5OmniThinker configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config) + + >>> # Initializing a model from the Qwen-Omni style configuration + >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_thinker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } + sub_configs = { + "audio_config": Qwen2_5OmniAudioEncoderConfig, + "vision_config": Qwen2_5OmniVisionEncoderConfig, + "text_config": Qwen2_5OmniTextConfig, + } + + def __init__( + self, + audio_config=None, + vision_config=None, + text_config=None, + audio_token_index=151646, + image_token_index=151655, + video_token_index=151656, + position_id_per_seconds=25, + seconds_per_chunk=2, + audio_start_token_id=151647, + audio_end_token_id=151648, + user_token_id=872, + initializer_range=0.02, + **kwargs, + ): + self.audio_token_index = audio_token_index + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.user_token_id = user_token_id + self.position_id_per_seconds = position_id_per_seconds + self.seconds_per_chunk = seconds_per_chunk + self.audio_start_token_id = audio_start_token_id + self.audio_end_token_id = audio_end_token_id + self.initializer_range = initializer_range + + if isinstance(vision_config, dict): + vision_config = Qwen2_5OmniVisionEncoderConfig(**vision_config) + elif vision_config is None: + vision_config = Qwen2_5OmniVisionEncoderConfig() + self.vision_config = vision_config + + if isinstance(audio_config, dict): + audio_config = Qwen2_5OmniAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen2_5OmniAudioEncoderConfig() + self.audio_config = audio_config + + if isinstance(text_config, dict): + text_config = Qwen2_5OmniTextConfig(**text_config) + elif text_config is None: + text_config = Qwen2_5OmniTextConfig() + self.text_config = text_config + + super().__init__(**kwargs) + + +class Qwen2_5OmniTalkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniTalkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Talker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_token_index (`int`, *optional*, defaults to 151646): + The audio token index to encode the audio prompt. 
+ image_token_index (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 151656): + The video token index to encode the video prompt. + vocab_size (`int`, *optional*, defaults to 8448): + Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + tts_text_start_token_id (`int`, *optional*, defaults to 151860): + The tts text start token index to encode the start of tts text. + tts_text_end_token_id (`int`, *optional*, defaults to 151861): + The tts text end token index to encode the end of tts text. + tts_text_pad_token_id (`int`, *optional*, defaults to 151859): + The tts text pad token index to encode the pad of tts text. + tts_codec_start_token_id (`int`, *optional*, defaults to 8293): + The tts codec start token index to encode the start of tts codec. + tts_codec_end_token_id (`int`, *optional*, defaults to 8294): + The tts codec end token index to encode the end of tts codec. + tts_codec_pad_token_id (`int`, *optional*, defaults to 8292): + The tts codec pad token index to encode the pad of tts codec. + tts_codec_mask_token_id (`int`, *optional*, defaults to 8296): + The tts codec mask token index to encode the mask of tts codec. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The tts vision start token index to encode the start of vision. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The tts vision end token index to encode the end of vision. + embedding_size (`int`, *optional*, defaults to 3584): + Dimension of the embedding representations. + hidden_size (`int`, *optional*, defaults to 3584): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 28): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + head_dim (`int`, *optional*, defaults to 128): + The dimension of each attention head. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 32768):
+            Sliding window attention (SWA) window size.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        position_id_per_seconds (`int`, *optional*, defaults to 25):
+            The increment of position id per second.
+        seconds_per_chunk (`int`, *optional*, defaults to 2):
+            The duration in seconds of the chunk of audio and video data.
+        audio_start_token_id (`int`, *optional*, defaults to 151647):
+            The audio start token index to encode the audio prompt. 
+        audio_end_token_id (`int`, *optional*, defaults to 151648):
+            The audio end token index to encode the end of the audio prompt.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        spatial_merge_size (`int`, *optional*, defaults to 2):
+            The size used for merging spatial dimensions.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniTalkerConfig
+
+    >>> # Initializing a Qwen2_5OmniTalker configuration
+    >>> configuration = Qwen2_5OmniTalkerConfig()
+
+    >>> # Initializing a model from the Qwen-Omni style configuration
+    >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_talker"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "video_token_id": "video_token_index",
+        "audio_token_id": "audio_token_index",
+    }
+
+    def __init__(
+        self,
+        audio_token_index=151646,
+        image_token_index=151655,
+        video_token_index=151656,
+        vocab_size=8448,
+        tts_text_start_token_id=151860,
+        tts_text_end_token_id=151861,
+        tts_text_pad_token_id=151859,
+        tts_codec_start_token_id=8293,
+        tts_codec_end_token_id=8294,
+        tts_codec_pad_token_id=8292,
+        tts_codec_mask_token_id=8296,
+        vision_start_token_id=151652,
+        vision_end_token_id=151653,
+        embedding_size=3584,
+        hidden_size=3584,
+        intermediate_size=18944,
+        num_hidden_layers=28,
+        num_attention_heads=28,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        rms_norm_eps=1e-06,
+        head_dim=128,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=32768,
+        max_window_layers=28,
+        attention_dropout=0.0,
+        rope_scaling=None,
+        position_id_per_seconds=25,
+        seconds_per_chunk=2,
+        audio_start_token_id=151647,
+        audio_end_token_id=151648,
+        initializer_range=0.02,
+        spatial_merge_size=2,
+        **kwargs,
+    ):
+        self.audio_token_index = audio_token_index
+        self.image_token_index = image_token_index
+        self.video_token_index = video_token_index
+
+        self.tts_text_start_token_id = tts_text_start_token_id
+        self.tts_text_end_token_id = tts_text_end_token_id
+        self.tts_text_pad_token_id = tts_text_pad_token_id
+        self.tts_codec_start_token_id = tts_codec_start_token_id
+        self.tts_codec_end_token_id = tts_codec_end_token_id
+        self.tts_codec_pad_token_id = tts_codec_pad_token_id
+
+        self.tts_codec_mask_token_id = tts_codec_mask_token_id
+
+        self.vision_start_token_id = vision_start_token_id
+        self.vision_end_token_id = vision_end_token_id
+
+        self.vocab_size = vocab_size
+        self.head_dim = head_dim
+        self.embedding_size = embedding_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = 
num_key_value_heads
+        self.hidden_act = hidden_act
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+        self.position_id_per_seconds = position_id_per_seconds
+        self.seconds_per_chunk = seconds_per_chunk
+        self.audio_start_token_id = audio_start_token_id
+        self.audio_end_token_id = audio_end_token_id
+
+        self.initializer_range = initializer_range
+        self.spatial_merge_size = spatial_merge_size
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5OmniDiTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavDiT used in the Qwen2.5-Omni-Token2Wav model.
+    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            The dimension of the model.
+        num_hidden_layers (`int`, *optional*, defaults to 22):
+            The number of transformer blocks in the DiT model.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            The number of attention heads in each transformer block.
+        ff_mult (`int`, *optional*, defaults to 2):
+            The multiplier for the feedforward layer in each transformer block.
+        emb_dim (`int`, *optional*, defaults to 512):
+            The dimension of the embedding layer.
+        head_dim (`int`, *optional*, defaults to 64):
+            The dimension of each attention head.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        block_size (`int`, *optional*, defaults to 24):
+            The block size used by the block-wise attention pattern in the DiT model.
+        look_ahead_layers (`List[int]`, *optional*, defaults to `[10]`):
+            The indices of the transformer blocks that attend to look-ahead (future) blocks.
+        look_backward_layers (`List[int]`, *optional*, defaults to `[0, 20]`):
+            The indices of the transformer blocks that attend to look-backward (past) blocks.
+        repeats (`int`, *optional*, defaults to 2):
+            The number of times the codec embeddings are repeated.
+        num_embeds (`int`, *optional*, defaults to 8193):
+            The number of unique embeddings in the codec.
+        mel_dim (`int`, *optional*, defaults to 80):
+            The dimension of the mel-spectrogram.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout rate for the transformer blocks.
+
+        enc_emb_dim (`int`, *optional*, defaults to 192):
+            The dimension of the pre-trained speaker embedding.
+        enc_dim (`int`, *optional*, defaults to 128):
+            The dimension of the encoder output.
+        enc_channels (`List[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
+            A list of output channels for each TDNN/SERes2Net layer in the encoder.
+        enc_kernel_sizes (`List[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
+            A list of kernel sizes for each layer in the encoder.
+        enc_dilations (`List[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
+            A list of dilations for each layer in the encoder.
+        enc_attention_channels (`int`, *optional*, defaults to 64):
+            The number of attention channels in the SqueezeExcitationBlock.
+        enc_res2net_scale (`int`, *optional*, defaults to 2):
+            The scale of the Res2Net block in the encoder.
+        enc_se_channels (`int`, *optional*, defaults to 64):
+            The number of output channels after squeeze in the SqueezeExcitationBlock. 
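+
+    Example (a minimal sketch; this sub-config is not exported in this file's `__all__`, so it is imported via its module path):
+
+    ```python
+    >>> from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniDiTConfig
+
+    >>> # Initializing a DiT configuration with default values
+    >>> configuration = Qwen2_5OmniDiTConfig()
+    ```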
+ """ + + model_type = "qwen2_5_omni_dit" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=22, + num_attention_heads=16, + ff_mult=2, + emb_dim=512, + head_dim=64, + rope_theta=10000.0, + max_position_embeddings=32768, + block_size=24, + look_ahead_layers=[10], + look_backward_layers=[0, 20], + repeats=2, + num_embeds=8193, + mel_dim=80, + dropout=0.1, + enc_emb_dim=192, + enc_dim=128, + enc_channels=[256, 256, 256, 256, 768], + enc_kernel_sizes=[5, 3, 3, 3, 1], + enc_dilations=[1, 2, 3, 4, 1], + enc_attention_channels=64, + enc_res2net_scale=2, + enc_se_channels=64, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.ff_mult = ff_mult + self.emb_dim = emb_dim + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.block_size = block_size + self.look_ahead_layers = look_ahead_layers + self.look_backward_layers = look_backward_layers + self.repeats = repeats + self.num_embeds = num_embeds + self.mel_dim = mel_dim + self.dropout = dropout + self.enc_emb_dim = enc_emb_dim + self.enc_dim = enc_dim + self.enc_channels = enc_channels + self.enc_kernel_sizes = enc_kernel_sizes + self.enc_dilations = enc_dilations + self.enc_attention_channels = enc_attention_channels + self.enc_res2net_scale = enc_res2net_scale + self.enc_se_channels = enc_se_channels + super().__init__(**kwargs) + + +class Qwen2_5OmniBigVGANConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavBigVGAN module used in the Qwen2.5-Omni-Token2Wav model. + It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms. + + Args: + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 1536): + The number of channels in the initial upsampling layer. + resblock_kernel_sizes (`List[int]`, *optional*, defaults to `[3, 7, 11]`): + A list of kernel sizes for each residual block. + resblock_dilation_sizes (`List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A list of dilation sizes for each residual block. + upsample_rates (`List[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`): + A list of upsampling rates for each upsampling layer. + upsample_kernel_sizes (`List[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`): + A list of kernel sizes for each upsampling layer. + """ + + model_type = "qwen2_5_omni_bigvgan" + + def __init__( + self, + mel_dim=80, + upsample_initial_channel=1536, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + **kwargs, + ): + self.mel_dim = mel_dim + self.upsample_initial_channel = upsample_initial_channel + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + super().__init__(**kwargs) + + +class Qwen2_5OmniToken2WavConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniToken2WavModel`]. + It is used to instantiate the Qwen2.5-Omni-Token2Wav model which combines a Diffusion Transformer (DiT) for mel-spectrogram generation with a BigVGAN model for waveform synthesis. 
The configuration contains sub-configurations for both components.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        dit_config ([`Qwen2_5OmniDiTConfig`], *optional*):
+            Configuration for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
+        bigvgan_config ([`Qwen2_5OmniBigVGANConfig`], *optional*):
+            Configuration for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniToken2WavModel, Qwen2_5OmniToken2WavConfig
+
+    >>> # Initialize the DiT sub-configuration (passed as a dict of keyword arguments)
+    >>> dit_config = {
+    ...     "hidden_size": 1024,
+    ...     "num_hidden_layers": 22,
+    ...     "num_attention_heads": 16,
+    ...     "ff_mult": 2,
+    ... }
+
+    >>> # Initialize the BigVGAN sub-configuration
+    >>> bigvgan_config = {
+    ...     "mel_dim": 80,
+    ...     "upsample_rates": [5, 3, 2, 2, 2, 2],
+    ... }
+
+    >>> # Initialize main configuration
+    >>> config = Qwen2_5OmniToken2WavConfig(dit_config, bigvgan_config)
+
+    >>> # Initialize model with config
+    >>> model = Qwen2_5OmniToken2WavModel(config)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni_token2wav"
+    sub_configs = {
+        "dit_config": Qwen2_5OmniDiTConfig,
+        "bigvgan_config": Qwen2_5OmniBigVGANConfig,
+    }
+
+    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
+        if dit_config is None:
+            dit_config = {}
+        if bigvgan_config is None:
+            bigvgan_config = {}
+        self.dit_config = Qwen2_5OmniDiTConfig(**dit_config)
+        self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**bigvgan_config)
+        super().__init__(**kwargs)
+
+
+class Qwen2_5OmniConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniForConditionalGeneration`]. It is used to instantiate a Qwen2.5Omni
+    model according to the specified sub-model configurations, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
+        talker_config (`dict`, *optional*): Configuration of the underlying talker sub-model.
+        token2wav_config (`dict`, *optional*): Configuration of the underlying codec sub-model.
+        enable_audio_output (`bool`, *optional*, defaults to `True`): Whether to enable audio output and load the talker and token2wav modules.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     Qwen2_5OmniThinkerConfig,
+    ...     Qwen2_5OmniTalkerConfig,
+    ...     Qwen2_5OmniToken2WavConfig,
+    ...     Qwen2_5OmniForConditionalGeneration,
+    ...     Qwen2_5OmniConfig,
+    ... )
+
+    >>> # Initializing sub-modules configurations.
+    >>> thinker_config = Qwen2_5OmniThinkerConfig()
+    >>> talker_config = Qwen2_5OmniTalkerConfig()
+    >>> token2wav_config = Qwen2_5OmniToken2WavConfig()
+
+    >>> # Initializing a module style configuration
+    >>> configuration = Qwen2_5OmniConfig(
+    ...     thinker_config.to_dict(), talker_config.to_dict(), token2wav_config.to_dict()
+    ... 
)
+
+    >>> # Initializing a model (with random weights)
+    >>> model = Qwen2_5OmniForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni"
+    sub_configs = {
+        "thinker_config": Qwen2_5OmniThinkerConfig,
+        "talker_config": Qwen2_5OmniTalkerConfig,
+        "token2wav_config": Qwen2_5OmniToken2WavConfig,
+    }
+
+    def __init__(
+        self,
+        thinker_config=None,
+        talker_config=None,
+        token2wav_config=None,
+        enable_audio_output: bool = True,
+        **kwargs,
+    ):
+        if thinker_config is None:
+            thinker_config = {}
+            logger.info("thinker_config is None. Initializing thinker model with default values")
+
+        if talker_config is None:
+            talker_config = {}
+            logger.info("talker_config is None. Initializing talker model with default values")
+
+        if token2wav_config is None:
+            token2wav_config = {}
+            logger.info("token2wav_config is None. Initializing token2wav model with default values")
+
+        self.thinker_config = Qwen2_5OmniThinkerConfig(**thinker_config)
+        self.talker_config = Qwen2_5OmniTalkerConfig(**talker_config)
+        self.token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config)
+        self.enable_audio_output = enable_audio_output
+
+        super().__init__(**kwargs)
+
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: this method is currently only used by vLLM.
+        return self.thinker_config.get_text_config()
+
+
+__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
diff --git a/docs/transformers/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/docs/transformers/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b23b3949f2495aaca45efb25564ba51d122c36
--- /dev/null
+++ b/docs/transformers/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -0,0 +1,4695 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen2_5_omni.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+    is_torch_flex_attn_available,
+    logging,
+    replace_return_docstrings,
+)
+from ...utils.hub import cached_file
+from .configuration_qwen2_5_omni import (
+    Qwen2_5OmniAudioEncoderConfig,
+    Qwen2_5OmniBigVGANConfig,
+    Qwen2_5OmniConfig,
+    Qwen2_5OmniDiTConfig,
+    Qwen2_5OmniTalkerConfig,
+    Qwen2_5OmniTextConfig,
+    Qwen2_5OmniThinkerConfig,
+    Qwen2_5OmniToken2WavConfig,
+    Qwen2_5OmniVisionEncoderConfig,
+)
+
+
+if is_flash_attn_2_available():
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
+    from flash_attn.layers.rotary import apply_rotary_emb
+else:
+    flash_attn_varlen_func = None
+    apply_rotary_emb = None
+
+
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+Qwen2_5Omni_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Qwen2_5OmniConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+"""
+
+
+@add_start_docstrings(
+    "The bare Qwen2_5Omni Model outputting raw hidden-states without any specific head on top.",
+    Qwen2_5Omni_START_DOCSTRING,
+)
+class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
+    config_class = Qwen2_5OmniConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Qwen2_5OmniDecoderLayer", "Qwen2_5OmniVisionBlock"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
+        # inference and fine-tuning - so the proper init weights code has been removed
+        std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
+
+        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
+        elif isinstance(module, Qwen2RMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        self,
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        min_dtype: float,
+        cache_position: torch.Tensor,
+        batch_size: int,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            min_dtype (`float`):
+                The minimum value representable with the dtype `dtype`.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+    def get_llm_pos_ids_for_vision(
+        self,
+        start_idx: int,
+        vision_idx: int,
+        spatial_merge_size: int,
+        t_index: List[int],
+        grid_hs: List[int],
+        grid_ws: List[int],
+    ):
+        llm_pos_ids_list = []
+        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
+        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
+        t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
+        _llm_pos_ids = torch.stack([t_index, h_index, w_index])
+        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
+        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+        return llm_pos_ids
+
+    def get_chunked_index(
+        self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
+    ) -> list[tuple[int, int]]:
+        """
+        Splits token index list into chunks based on token value ranges.
+
+        Given a list of token indices, returns a list of (start, end) index tuples representing
+        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+        - the first chunk contains token values < 1000,
+        - the second chunk contains values >= 1000 and < 2000, and so on.
+
+        Args:
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
+            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+            remove_index (`int`): An index id to subtract from `token_indices` before chunking.
+
+        Returns:
+            `List[Tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+            and end (exclusive) indices of a chunk in `token_indices`. 
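+
+        Example (illustrative values; `model` stands for any module exposing this method):
+
+        ```python
+        >>> token_indices = torch.tensor([0, 20, 40, 1005, 1050, 2001])
+        >>> model.get_chunked_index(token_indices, tokens_per_chunk=1000, remove_index=0)
+        [(0, 3), (3, 5), (5, 6)]
+        ```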
+        """
+
+        def _iter():
+            i, start_idx = 0, 0  # skip bos token
+            current_chunk = 1
+            while i < len(token_indices):  # skip eos token
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
+                    yield (start_idx, i)
+                    start_idx = i
+                    current_chunk += 1
+                i += 1
+            yield (start_idx, len(token_indices))
+
+        return list(_iter())
+
+    def get_rope_index(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+        audio_seqlens: Optional[torch.LongTensor] = None,
+        second_per_grids: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+        Explanation:
+            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+            For a pure text embedding sequence, the rotary position embedding is no different from that of modern LLMs.
+            Examples:
+                input_ids: [T T T T T], here T is for text.
+                temporal position_ids: [0, 1, 2, 3, 4]
+                height position_ids: [0, 1, 2, 3, 4]
+                width position_ids: [0, 1, 2, 3, 4]
+
+            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+            and 1D rotary position embedding for text part.
+            Examples:
+                Temporal (Time): 3 patches, representing different segments of the video in time.
+                Height: 2 patches, dividing each frame vertically.
+                Width: 2 patches, dividing each frame horizontally.
+                We also have some important parameters:
+                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+                text temporal position_ids: [101, 102, 103, 104, 105]
+                text height position_ids: [101, 102, 103, 104, 105]
+                text width position_ids: [101, 102, 103, 104, 105]
+                Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image in LLM.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM. 
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + use_audio_in_video (`bool`, *optional*): + If set to `True`, use the audio in video. + audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id + vision_start_token_id = self.config.vision_start_token_id + audio_start_token_id = self.config.audio_start_token_id + position_id_per_seconds = self.config.position_id_per_seconds + seconds_per_chunk = self.config.seconds_per_chunk + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = ( + image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio = input_tokens.index(audio_token_id, st) + else: + ed_audio = len(input_tokens) + 1 + min_ed = min(ed_image, ed_video, ed_audio) + if min_ed == ed_audio: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + 
st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + + elif min_ed == ed_image: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + + elif min_ed == ed_video and not use_audio_in_video: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + + elif min_ed == ed_video and use_audio_in_video: + text_len = min_ed - st - 2 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if 
len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + video_llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + + t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + sub_len = 0 + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None + audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None + if video_chunk_index is not None: + sub_len += video_chunk_index[1] - video_chunk_index[0] + + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]] + ) + if audio_chunk_index is not None: + sub_len += audio_chunk_index[1] - audio_chunk_index[0] + + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + + return position_ids, mrope_position_deltas + else: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas + + +############################ +# Start Thinker # +############################ + + +@dataclass +class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5OmniAudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Qwen2_5OmniAudioEncoderConfig, + ): + super().__init__() + self.embed_dim = config.d_model + self.num_heads = config.encoder_attention_heads + self.dropout = config.attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.config = config + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = False + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + query_states = query_states.transpose(0, 1) + key_states = key_states.transpose(0, 1) + value_states = value_states.transpose(0, 1) + attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) + + attention_mask = torch.full( + [1, seq_length, key_states.shape[1]], + torch.finfo(query_states.dtype).min, + device=query_states.device, + dtype=query_states.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) + + attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention): + """ + Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + seq_length, all_dim = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.reshape(seq_length, self.num_heads, -1) + + key_states = self.k_proj(hidden_states) + key_states = key_states.reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states) + value_states = value_states.reshape(seq_length, self.num_heads, -1) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func( + query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0 + ) + attn_output = attn_output.reshape(seq_length, all_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention): + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + attention_mask = torch.zeros( + [1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + query_states = query_states.transpose(0, 1) + key_states = key_states.transpose(0, 1) + value_states = value_states.transpose(0, 1) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + ) + attn_output = attn_output.transpose(0, 1) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+        attn_output = attn_output.reshape(seq_length, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output
+
+
+QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
+    "eager": Qwen2_5OmniAudioAttention,
+    "flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
+    "sdpa": Qwen2_5OmniAudioSdpaAttention,
+}
+
+
+class Qwen2_5OmniAudioEncoderLayer(nn.Module):
+    def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, embed_dim)`
+            cu_seqlens (`torch.Tensor`): cumulative sequence lengths, used to build the block-diagonal
+                attention mask so that tokens only attend within their own audio segment.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            cu_seqlens=cu_seqlens,
+        )
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        return outputs
+
+
+class SinusoidsPositionEmbedding(nn.Module):
+    def __init__(self, length, channels, max_timescale=10000):
+        super().__init__()
+        if channels % 2 != 0:
+            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)).float()
+        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        self.register_buffer(
+            "positional_embedding",
+            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+            persistent=False,
+        )
+
+    def forward(self, seqlen: int):
+        return self.positional_embedding[:seqlen, :]
+
+
+QWEN_OMNI_AUDIO_INPUTS = r"""
+    Args:
+        input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
+            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
+            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into
+            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+            and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`].
+        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
+            The length of each mel-feature sequence in `input_features`.
+        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
+            The length of each feature sequence after the convolutional downsampling.
+    """
+
+
+class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`Qwen2_5OmniAudioEncoderLayer`].
+
+    Args:
+        config: Qwen2_5OmniAudioEncoderConfig
+    """
+
+    config_class = Qwen2_5OmniAudioEncoderConfig
+    main_input_name = "input_features"
+    _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"]
+    _supports_sdpa = True
+
+    def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+        super().__init__(config)
+        self.dropout = config.dropout
+
+        embed_dim = config.d_model
+        self.num_mel_bins = config.num_mel_bins
+        self.max_source_positions = config.max_source_positions
+        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        self.n_window = config.n_window
+        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+        self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
+        self.audio_bos_eos_token = nn.Embedding(2, config.output_dim)
+        self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.ln_post = nn.LayerNorm(config.d_model)
+        self.avg_pooler = nn.AvgPool1d(2, stride=2)
+        self.proj = nn.Linear(config.d_model, config.output_dim)
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.conv1
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.conv1 = value
+
+    @add_start_docstrings_to_model_forward(QWEN_OMNI_AUDIO_INPUTS)
+    def forward(
+        self,
+        input_features,
+        feature_lens=None,
+        aftercnn_lens=None,
+    ):
+        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
+
+        chunk_lengths = torch.tensor(
+            [self.n_window * 2] * chunk_num.sum(),
+            dtype=torch.long,
+            device=feature_lens.device,
+        )
+        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
+        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
+        chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths)
+
+        chunk_list = input_features.split(chunk_lengths.tolist(), dim=1)
+        padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function(
+            chunk_list, chunk_lengths, padding_value=0, padding_side="right"
+        )
+        padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
+        padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)
+
+        padded_embed = padded_embed + self.positional_embedding.positional_embedding[
+            : padded_embed.shape[1], :
+        ].unsqueeze(0).to(padded_embed.dtype)
+        hidden_states = padded_embed[padded_mask_after_cnn]
+        cu_seqlens = torch.cat(
+            (
+                torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
+                padded_mask_after_cnn.sum(1).cumsum(0),
+            )
+        ).to(torch.int32)
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if self.gradient_checkpointing and self.training:
+                
layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + cu_seqlens, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states) + each_audio_states = self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + token_audio = torch.cat(token_audio_list, dim=0) + return BaseModelOutput(last_hidden_state=token_audio) + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. + """ + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5OmniVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = 
apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + tensor_ = tensor.float() + cos = freqs.cos() # .type_as(tensor_) + sin = freqs.sin() # .type_as(tensor_) + output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) + return output + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, 
attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5OmniVisionAttention, + "flash_attention_2": Qwen2_5OmniVisionFlashAttention2, + "sdpa": Qwen2_5OmniVisionSdpaAttention, +} + + +class Qwen2_5OmniVisionBlock(nn.Module): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5OmniMLP(config, bias=True) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2_5OmniPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = 
self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniVisionEncoderConfig + _no_split_modules = ["Qwen2_5OmniVisionBlock"] + + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen2_5OmniPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + 
cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+
+        return window_index, cu_window_seqlens
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                The packed, flattened patch pixel values of the input images / videos; these are projected by
+                `patch_embed` before being encoded.
+            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                The temporal, height and width of feature shape of each image in LLM.
+
+        Returns:
+            `torch.Tensor`: The encoded visual features of shape `(seq_len // spatial_merge_unit, out_hidden_size)`,
+            restored to the pre-windowing token order.
+        """
+        hidden_states = self.patch_embed(hidden_states)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+        cu_window_seqlens = torch.tensor(
+            cu_window_seqlens,
+            device=hidden_states.device,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+        seq_len, _ = hidden_states.size()
+        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            # Select dtype based on the following factors:
+            #  - FA2 requires that cu_seqlens_q must have dtype int32
+            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852 for more information
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        # Alternate between full attention (for blocks listed in `fullatt_block_indexes`)
+        # and window-limited attention (all other blocks) by swapping the cu_seqlens used.
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb
+                )
+            else:
+                hidden_states = blk(
+                    hidden_states,
+                    cu_seqlens=cu_seqlens_now,
+                    rotary_pos_emb=rotary_pos_emb,
+                )
+        hidden_states = self.merger(hidden_states)
+        reverse_indices = torch.argsort(window_index)
+        hidden_states = hidden_states[reverse_indices, :]
+
+        return hidden_states
+
+
+class Qwen2_5OmniRotaryEmbedding(nn.Module):
+    def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        # In contrast to other models, Qwen2_5Omni has different position ids for the grids,
+        # so we expand the inv_freq to shape (3, ...)
+        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+    Explanation:
+        Multimodal 3D rotary position embedding is an extension of 1D rotary position embedding. The input embedding
+        sequence may contain vision (image / video) embeddings alongside text embeddings, or text embeddings only.
+        For the vision part, rotary position embedding is applied separately on the temporal, height and width
+        dimensions, splitting the channel dimension into 3 chunks (one per dimension). For the text part, plain 1D
+        rotary position embedding is applied: the three rotary position indices (temporal, height and width) of a
+        text token are always identical, so the result is the same as in standard LLMs.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        mrope_section (`List[int]`):
+            Multimodal rope section sizes for the temporal, height and width chunks of the channel dimension used in
+            the rope calculation.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze `cos` and `sin` so that
+            they can be properly broadcasted to the dimensions of q and k. For example, if q and k have the shape
+            [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes `cos` and `sin` broadcastable
+            to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim],
+            then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
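+
+    Example (an illustrative sketch only; the `mrope_section` values and tensor sizes below are made-up toy
+    numbers, not the model's real configuration):
+
+    ```python
+    >>> import torch
+
+    >>> mrope_section = [2, 1, 1]  # head_dim = 2 * sum(mrope_section) = 8
+    >>> cos = torch.randn(3, 1, 4, 8)  # (plane: t/h/w, batch, seq, head_dim)
+    >>> sin = torch.randn(3, 1, 4, 8)
+    >>> q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
+    >>> k = torch.randn(1, 2, 4, 8)
+    >>> q_embed, k_embed = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
+    >>> q_embed.shape
+    torch.Size([1, 2, 4, 8])
+    ```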
+ """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5OmniAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + 
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2MLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5OmniFlashAttention2(Qwen2_5OmniAttention): + """ + Qwen2_5Omni flash attention module, following Qwen2_5Omni attention module. This module inherits from `Qwen2_5OmniAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
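+        # Illustration (comment only, not executed): with q_len=2 and kv_len=4, a
+        # bottom-right-aligned causal mask lets both queries attend to all cached keys,
+        # while top-left alignment would wrongly cut them off:
+        #
+        #   bottom-right (flash_attn>=2.1):    top-left (flash_attn<2.1):
+        #     [[1, 1, 1, 0],                     [[1, 0, 0, 0],
+        #      [1, 1, 1, 1]]                      [1, 1, 0, 0]]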
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        if (
+            self.config.use_sliding_window
+            and getattr(self.config, "sliding_window", None) is not None
+            and self.layer_idx >= self.config.max_window_layers
+        ):
+            sliding_window = self.config.sliding_window
+        else:
+            sliding_window = None
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            sliding_window=sliding_window,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class Qwen2_5OmniSdpaAttention(Qwen2_5OmniAttention):
+    """
+    Qwen2_5Omni attention module using `torch.nn.functional.scaled_dot_product_attention`. This module inherits from
+    `Qwen2_5OmniAttention`, as the weights of the module stay untouched. The only changes are on the forward pass, to
+    adapt to the SDPA API.
+    """
+
+    # Adapted from Qwen2Attention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Qwen2_5OmniModel is using Qwen2_5OmniSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_OMNI_ATTENTION_CLASSES = { + "eager": Qwen2_5OmniAttention, + "flash_attention_2": Qwen2_5OmniFlashAttention2, + "sdpa": Qwen2_5OmniSdpaAttention, +} + + +class Qwen2_5OmniDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+            )
+        self.self_attn = QWEN2_5_OMNI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+        self.mlp = Qwen2MLP(config)
+        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model.
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+QWEN2_5OMNI_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads etc.).
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+ + Parameters: + config ([`{config_class}`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.", + QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"), +) +class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniTextConfig + _no_split_modules = ["Qwen2_5OmniDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
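+        # Illustrative sketch (comment only, not executed): for a pure-text batch every
+        # token gets the same index on all three planes, so multimodal RoPE reduces to
+        # standard 1D RoPE. The expansion performed below does, in effect:
+        #
+        #   cache_position = torch.arange(4)                        # (4,)
+        #   pos = cache_position.view(1, 1, -1).expand(3, bsz, -1)  # (3, bsz, 4)
+        #   assert (pos[0] == pos[1]).all() and (pos[1] == pos[2]).all()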
+        if position_ids is None:
+            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+        elif position_ids.dim() == 2:
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: Union[torch.Tensor, "BlockMask"],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool = False,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right', which may lead"
+                        " to unexpected behaviour for the Flash Attention version of Qwen2_5OmniThinkerText. Make"
+                        " sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            return attention_mask
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
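+        # The mask built below is additive: allowed positions hold 0.0 and masked
+        # positions hold torch.finfo(dtype).min, so it can be summed onto the raw
+        # attention scores before the softmax. Minimal sketch of the same idea
+        # (simplified; ignores caching and sliding windows):
+        #
+        #   scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
+        #   scores = scores + causal_mask   # masked logits -> ~ -inf
+        #   probs = scores.softmax(dim=-1)  # masked positions -> ~ 0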
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5OmniConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+            config (`Qwen2_5OmniConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class that is currently being used to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.get_text_config().sliding_window is not None:
+                # if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we
+                # mask them out here as well. The check verifies whether the current checkpoint was trained with a
+                # sliding window or not.
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+QWEN2_5OMNITHINKER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
+            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained
+            by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`,
+            *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`,
+            the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into
+            a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`NewTaskModelProcessor`] uses
+            [`SiglipImageProcessor`] for processing images).
+        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*):
+            The tensors corresponding to the input videos. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`NewTaskModelProcessor`] uses
+            [`SiglipImageProcessor`] for processing videos).
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+            The length of feature shape of each audio in LLM.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+        past_key_values (`list(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+            The rope index difference between sequence length and multimodal rope.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The Qwen2.5OmniThinker model, which consists of an audio backbone, a vision backbone and a language model.""",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniThinkerConfig"),
+)
+class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+    config_class = Qwen2_5OmniThinkerConfig
+    base_model_prefix = "thinker"
+    _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"]
+
+    def __init__(self, config: Qwen2_5OmniThinkerConfig):
+        super().__init__(config)
+        self.audio_tower = Qwen2_5OmniAudioEncoder._from_config(
+            config.audio_config, attn_implementation=config._attn_implementation
+        )
+
+        self.visual = Qwen2_5OmniVisionEncoder._from_config(
+            config.vision_config, attn_implementation=config._attn_implementation
+        )
+
+        self.vocab_size = config.text_config.vocab_size
+        self.model = Qwen2_5OmniThinkerTextModel._from_config(
+            config.text_config, attn_implementation=config._attn_implementation
+        )
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.spatial_merge_size = config.vision_config.spatial_merge_size
+        self.rope_deltas = None
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    @add_start_docstrings_to_model_forward(QWEN2_5OMNITHINKER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=Qwen2_5OmniThinkerCausalLMOutputWithPast, config_class="Qwen2_5OmniThinkerConfig"
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        input_features: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        feature_attention_mask: Optional[torch.Tensor] = None,
+        audio_feature_lengths: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_audio_in_video: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        video_second_per_grid: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from qwen_vl_utils import process_vision_info
+        >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
+
+        >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+        >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+        >>> conversations = [
+        ...     {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'},
+        ...     {"role": "user", "content": [
+        ...         {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+        ...         {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
+        ...     ]},
+        ... ]
+
+        >>> text = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+        >>> audios = [librosa.load(BytesIO(urlopen(conversations[1]['content'][1]['audio_url']).read()), sr=processor.feature_extractor.sampling_rate)[0]]
+        >>> images, videos = process_vision_info(conversations)
+        >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
+
+        >>> # Generate
+        >>> inputs["use_audio_in_video"] = True  # or False
+        >>> generation = thinker.generate(**inputs, max_new_tokens=2048)
+        >>> generate_ids = generation[:, inputs.input_ids.size(1):]
+
+        >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+            input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
+        else:
+            audio_feature_lengths = None
+        if attention_mask is not None and position_ids is None:
+            if (
+                cache_position is None
+                or (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+            ):
+                # Count the left-padding tokens per sequence so the rope deltas can be shifted accordingly.
+                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask,
+                    use_audio_in_video,
+                    audio_feature_lengths,
+                    video_second_per_grid,
+                )
+                rope_deltas = rope_deltas - delta0
+                self.rope_deltas = rope_deltas
+            else:
+                batch_size, seq_length = input_ids.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=input_ids.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        if inputs_embeds is None:
+            # 1. Extract the input embeddings
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            # 2. Merge text, audio, image and video features
+            if input_ids is not None and input_ids.shape[1] != 1:  # Prefill stage
+                if input_features is not None:
+                    audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
+                        audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+                    )
+                    feature_lens = (
+                        audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
+                    )
+                    audio_outputs = self.audio_tower(
+                        input_features,
+                        feature_lens=feature_lens,
+                        aftercnn_lens=audio_feat_lengths,
+                    )
+                    audio_features = audio_outputs.last_hidden_state
+                    if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
+                        raise ValueError("length of audio_features should match audio_output_lengths")
+                    audio_mask = (
+                        (input_ids == self.config.audio_token_id)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
+                    )
+                    audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                    inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
+
+                if pixel_values is not None:
+                    pixel_values = pixel_values.type(self.visual.dtype)
+                    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                    image_mask = (
+                        (input_ids == self.config.image_token_id)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
+                    )
+                    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+                if pixel_values_videos is not None:
+                    pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+                    video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                    video_mask = (
+                        (input_ids == self.config.video_token_id)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
+                    )
+                    video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                    inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+        outputs = self.model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs
+            return (loss,) + output if loss is not None else output
+
+        return Qwen2_5OmniThinkerCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.rope_deltas,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        input_features=None,
+        feature_attention_mask=None,
+        use_audio_in_video=False,
+        video_second_per_grid=None,
+        **kwargs,
+    ):
+        model_inputs = super().prepare_inputs_for_generation(
input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + +############################ +# Start Talker # +############################ + + +@dataclass +class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
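+        thinker_reply_part (`torch.FloatTensor`, *optional*):
+            The not-yet-consumed remainder of the thinker's reply hidden states; it is returned so that generation
+            can keep feeding the remaining part to the talker step by step.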
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + thinker_reply_part: torch.FloatTensor = None + + +@add_start_docstrings( + "The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.", + QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"), +) +class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniTalkerConfig + _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniTalker. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
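+        # Editorial note: the `target_length` computed below is the full cache capacity for a
+        # static or sliding-window cache (e.g. 512), so for a 3-token query the 4D mask is
+        # built over all 512 key slots and the not-yet-filled slots remain masked out.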
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5OmniConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+            config (`Qwen2_5OmniTalkerConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class that is currently being used to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.get_text_config().sliding_window is not None:
+                # If we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out too.
+                # The check is needed to verify whether the current checkpoint was trained with a sliding window or not.
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
+    config_class = Qwen2_5OmniTalkerConfig
+    base_model_prefix = "talker"
+
+    def __init__(self, config: Qwen2_5OmniTalkerConfig):
+        super().__init__(config)
+
+        self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size)
+
+        self.model = Qwen2_5OmniTalkerModel(config)
+        self.codebook_size = config.vocab_size
+        self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False)
+
+        self.codec_bos_token = config.tts_codec_start_token_id
+        self.codec_eos_token = config.tts_codec_end_token_id
+        self.codec_pad_token = config.tts_codec_pad_token_id
+        self.codec_mask_token = config.tts_codec_mask_token_id
+
+        self.text_bos_token = config.tts_text_start_token_id
+        self.text_eos_token = config.tts_text_end_token_id
+        self.text_pad_token = config.tts_text_pad_token_id
+
+        self.spatial_merge_size = self.config.spatial_merge_size
+        self.rope_deltas = None
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    @add_start_docstrings_to_model_forward(QWEN2_5OMNITHINKER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=Qwen2_5OmniTalkerCausalLMOutputWithPast, config_class="Qwen2_5OmniTalkerConfig"
+    )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        thinker_reply_part: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        input_text_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        use_audio_in_video: Optional[bool] = None,
+        audio_feature_lengths: Optional[torch.LongTensor] = None,
+        video_second_per_grid: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration
+
+        >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+        >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+        >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+        >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+        >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Generate the caption in English: Glass is breaking."
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None and position_ids is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + position_ids, rope_deltas = self.get_rope_index( + input_text_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + + inputs_embeds[:, -1, :] += self.get_input_embeddings()( + torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device) + ) + inputs_embeds[:, -2, :] += self.get_input_embeddings()( + torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device) + ) + self.rope_deltas = rope_deltas + + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if inputs_embeds is None: + # 1. Inference tokens after second token + codec_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :] + if thinker_reply_part.shape[1] > 1: + thinker_reply_part = thinker_reply_part[:, 1:, :] + + talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=talker_lm_input, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.codec_head(hidden_states) + logits = logits.float() + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniTalkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + thinker_reply_part=thinker_reply_part, + ) + + def _get_initial_cache_position(self, input_ids, model_kwargs): + # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily + inputs_embeds = model_kwargs.pop("inputs_embeds") + model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs["inputs_embeds"] = inputs_embeds + return model_kwargs + + # prepare inputs for talker lm generation + def prepare_inputs_for_generation( + self, + input_ids, + input_text_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + thinker_reply_part=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_audio_features=None, + audio_feature_attention_mask=None, + audio_feature_lengths=None, + use_audio_in_video=False, + video_second_per_grid=None, + 
**kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values, + attention_mask, + inputs_embeds, + cache_position, + use_cache=use_cache, + thinker_reply_part=thinker_reply_part, + input_text_ids=input_text_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_audio_in_video=use_audio_in_video, + audio_feature_lengths=audio_feature_lengths, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + + if getattr(outputs, "thinker_reply_part", None) is not None: + model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part + + return model_kwargs + + +############################ +# Start Token2Wav # +############################ + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + 
padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143). + """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + speaker_embedding: 
torch.Tensor,
+        condition_vector: torch.Tensor,
+        code_embed: torch.Tensor,
+        drop_audio_cond: Optional[bool] = False,
+        code_embed_uncond: Optional[bool] = None,
+        apply_cfg: Optional[bool] = True,
+    ):
+        if apply_cfg:
+            hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
+            speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
+            condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
+            code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
+        elif drop_audio_cond:  # cfg for cond audio
+            condition_vector = torch.zeros_like(condition_vector)
+            speaker_embedding = torch.zeros_like(speaker_embedding)
+        condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
+        hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))
+
+        return hidden_states
+
+
+# Transformer backbone using DiT blocks
+class DiTCodecEmbedding(nn.Module):
+    def __init__(self, codec_num_embeds, codec_dim, repeats):
+        super().__init__()
+        self.repeats = repeats
+        self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)
+
+    def forward(self, code, drop_code=False):
+        if drop_code:
+            code = torch.zeros_like(code)
+        code_embed = self.codec_embed(code)
+
+        code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1)
+        return code_embed
+
+
+# AdaLayerNormZero
+# Returns the modulated x for the attention input, plus the parameters for the later MLP modulation.
+class Qwen2_5_OmniAdaLayerNormZero(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 6)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, hidden_states, emb=None):
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+        hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for the final layer
+# Returns only the modulated x for the attention input, since there is no further MLP modulation.
+class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 2)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, hidden_states, emb):
+        emb = self.linear(self.silu(emb))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+
+        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return hidden_states
+
+
+# FeedForward
+class DiTMLP(nn.Module):
+    def __init__(self, dim, mult=4, dropout=0.0):
+        super().__init__()
+        inner_dim = int(dim * mult)
+
+        self.ff = nn.ModuleList(
+            [
+                nn.Linear(dim, inner_dim),
+                nn.GELU(approximate="tanh"),
+                nn.Dropout(dropout),
+                nn.Linear(inner_dim, dim),
+            ]
+        )
+
+    def forward(self, hidden_states):
+        for layer in self.ff:
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+# Modified from Llama with a different rotate function; will be fixed in the next release.
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+
+    def rotate_half_codec(x):
+        # x = rearrange(x, "... (d r) -> ... d r", r=2)
+        x = x.reshape(*x.shape[:-1], -1, 2)
+        x1, x2 = x.unbind(dim=-1)
+        x = torch.stack((-x2, x1), dim=-1)
+        return x.reshape(*x.shape[:-2], -1)
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half_codec(q) * sin)
+    k_embed = (k * cos) + (rotate_half_codec(k) * sin)
+    return q_embed, k_embed
+
+
+class DiTAttention(nn.Module):
+    def __init__(self, config: Qwen2_5OmniDiTConfig):
+        super().__init__()
+
+        self.config = config
+        self.dim = config.hidden_size
+        self.heads = config.num_attention_heads
+        self.inner_dim = config.head_dim * config.num_attention_heads
+        self.dropout = config.dropout
+        self._attn_implementation = config._attn_implementation
+        self.is_causal = False
+
+        self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
+        self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
+        self.to_v = nn.Linear(config.hidden_size, self.inner_dim)
+
+        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])
+
+    def forward(
+        self,
+        hidden_states,  # noised input x
+        position_embeddings=None,  # rotary position embedding for x
+        attention_mask=None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = self.to_q(hidden_states)
+        key = self.to_k(hidden_states)
+        value = self.to_v(hidden_states)
+
+        # attention
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.heads
+        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+        # apply rotary position embedding
+        # Due to the training process, RoPE is applied only to the first head; this will be fixed in the next release.
+        cos, sin = position_embeddings
+        query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
+
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
+        attention_weights, _ = attention_interface(
+            self,
+            query,
+            key,
+            value,
+            attention_mask=attention_mask,
+            is_causal=False,
+        )
+
+        # mask. e.g.
inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, 
hidden_states):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta := x + (1 / β) * sin²(α * x)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        alpha = torch.exp(alpha)
+        beta = torch.exp(beta)
+        hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
+            torch.sin(hidden_states * alpha), 2
+        )
+
+        return hidden_states
+
+
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
+    """Generates a 1D Kaiser-windowed sinc filter.
+
+    Args:
+        cutoff (float): Normalized cutoff frequency (0 to 0.5).
+        half_width (float): Transition bandwidth.
+        kernel_size (int): Number of filter taps.
+
+    Returns:
+        torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter.
+    """
+    is_even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+
+    # Compute Kaiser window parameters
+    delta_f = 4 * half_width
+    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+
+    if attenuation > 50.0:
+        beta = 0.1102 * (attenuation - 8.7)
+    elif attenuation >= 21.0:
+        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
+    else:
+        beta = 0.0
+
+    kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)
+
+    # Compute time indices
+    if is_even:
+        time_indices = torch.arange(-half_size, half_size) + 0.5
+    else:
+        time_indices = torch.arange(kernel_size) - half_size
+
+    # Compute sinc filter
+    if cutoff == 0:
+        return torch.zeros((1, 1, kernel_size), dtype=torch.float32)  # Ensures correct shape
+
+    sinc_filter = torch.sinc(2 * cutoff * time_indices)
+    normalized_filter = 2 * cutoff * kaiser_window * sinc_filter
+
+    # Normalize to ensure sum = 1 (avoid leakage of constant component)
+    normalized_filter /= normalized_filter.sum()
+
+    return normalized_filter.view(1, 1, kernel_size)
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+
+        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
+        self.register_buffer("filter", filter, persistent=False)
+
+    def forward(self, hidden_states):
+        channels = hidden_states.shape[1]
+
+        hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
+        hidden_states = self.ratio * F.conv_transpose1d(
+            hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels
+        )
+        hidden_states = hidden_states[..., self.pad_left : -self.pad_right]
+
+        return hidden_states
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        cutoff = 0.5 / ratio
+        half_width = 0.6 / ratio
+
+        if cutoff < 0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+
+        # Fall back to the same kernel sizing rule as UpSample1d when kernel_size is not given,
+        # so the default constructor does not fail on `None`.
+        kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = ratio
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter, persistent=False)
+
+    def forward(self, hidden_states):
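+        # Editorial note: this is anti-aliased downsampling; the input is padded, low-pass
+        # filtered with the Kaiser-windowed sinc kernel, and decimated by `ratio` via the
+        # strided depthwise convolution below.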
channels = hidden_states.shape[1] + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise ValueError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@add_start_docstrings( + "The full Qwen2.5Omni Token2WavBigVGAN model. 
It takes a mel spectrogram as input and predicts a waveform.",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniBigVGANConfig"),
+)
+class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
+    config_class = Qwen2_5OmniBigVGANConfig
+
+    def __init__(self, config: Qwen2_5OmniBigVGANConfig):
+        super().__init__(config)
+        self.num_residual_blocks = len(config.resblock_kernel_sizes)
+        self.num_upsample_layers = len(config.upsample_rates)
+
+        self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3)
+
+        # The extra nested ModuleList must be kept; removing it breaks loading of the official state dict
+        ups = [
+            nn.ModuleList(
+                [
+                    nn.ConvTranspose1d(
+                        config.upsample_initial_channel // (2**layer_idx),
+                        config.upsample_initial_channel // (2 ** (layer_idx + 1)),
+                        kernel_size,
+                        stride,
+                        padding=(kernel_size - stride) // 2,
+                    )
+                ]
+            )
+            for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes))
+        ]
+        self.ups = nn.ModuleList(ups)
+
+        self.resblocks = nn.ModuleList(
+            [
+                AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation)
+                for layer_idx in range(self.num_upsample_layers)
+                for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)
+            ]
+        )
+
+        self.activation_post = TorchActivation1d(
+            activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers))
+        )
+        self.conv_post = nn.Conv1d(
+            config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
+        )
+
+    def normalize_spectrogram(self, spectrogram, max_value, min_db):
+        return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)
+
+    def amplitude_to_db(self, amplitude, min_db_level):
+        min_level = torch.exp(
+            torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
+        )
+        return 20 * torch.log10(torch.clamp(amplitude, min=min_level))
+
+    def process_mel_spectrogram(self, mel_spectrogram):
+        amplitude_spectrum = torch.exp(mel_spectrogram)
+        decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
+        return self.normalize_spectrogram(decibel_spectrum, 1, -115)
+
+    def forward(self, mel_spectrogram):
+        processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
+        hidden_representation = self.conv_pre(processed_spectrogram)
+
+        for layer_index in range(self.num_upsample_layers):
+            hidden_representation = self.ups[layer_index][0](hidden_representation)
+            residual_output = sum(
+                self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation)
+                for block_index in range(self.num_residual_blocks)
+            )
+            residual_output = residual_output / self.num_residual_blocks
+            hidden_representation = residual_output
+
+        hidden_representation = self.activation_post(hidden_representation)
+        output_waveform = self.conv_post(hidden_representation)
+        return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu()
+
+
+class RungeKutta4ODESolver:
+    def __init__(self, function, initial_value):
+        self.function = function
+        self.initial_value = initial_value
+
+        self._one_third = 1 / 3
+        self._two_thirds = 2 / 3
+
+    def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None):
+        k1 = function_value_start if function_value_start is not None else function(time_start, value_start)
+        k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third)
+        k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third))
+        k4 = function(time_end, value_start + time_step * (k1 - k2 + k3))
+        return (k1 + 3 * (k2 + k3) + k4) * time_step / 8
+
+    def _compute_step(self, function, time_start, time_step, time_end, value_start):
+        function_value_start = function(time_start, value_start)
+        return self._rk4_step(
+            function, time_start, time_step, time_end, value_start, function_value_start=function_value_start
+        ), function_value_start
+
+    def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point):
+        if time_point == time_start:
+            return value_start
+        if time_point == time_end:
+            return value_end
+        weight = (time_point - time_start) / (time_end - time_start)
+        return value_start + weight * (value_end - value_start)
+
+    def integrate(self, time_points):
+        solution = torch.empty(
+            len(time_points),
+            *self.initial_value.shape,
+            dtype=self.initial_value.dtype,
+            device=self.initial_value.device,
+        )
+        solution[0] = self.initial_value
+
+        current_index = 1
+        current_value = self.initial_value
+        for time_start, time_end in zip(time_points[:-1], time_points[1:]):
+            time_step = time_end - time_start
+            delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value)
+            next_value = current_value + delta_value
+
+            while current_index < len(time_points) and time_end >= time_points[current_index]:
+                solution[current_index] = self._linear_interpolation(
+                    time_start, time_end, current_value, next_value, time_points[current_index]
+                )
+                current_index += 1
+
+            current_value = next_value
+
+        return solution
+
+
+@add_start_docstrings(
+    "The full Qwen2.5Omni Token2WavDiT model. It takes speech tokens as input and predicts a mel spectrogram.",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniDiTConfig"),
+)
+class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
+    config_class = Qwen2_5OmniDiTConfig
+    _no_split_modules = ["DiTDecoderLayer"]
+
+    def __init__(self, config: Qwen2_5OmniDiTConfig):
+        super().__init__(config)
+        self.mel_dim = config.mel_dim
+        self.repeats = config.repeats
+        self.time_embed = DiTTimestepEmbedding(config.hidden_size)
+
+        self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
+        self.input_embed = DiTInputEmbedding(config)
+
+        self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim)
+
+        self.hidden_size = config.hidden_size
+        self.layers = config.num_hidden_layers
+        self.block_size = config.block_size
+        self.num_attention_heads = config.num_attention_heads
+
+        self.transformer_blocks = nn.ModuleList()
+        for i in range(config.num_hidden_layers):
+            self.transformer_blocks.append(
+                DiTDecoderLayer(
+                    config,
+                    look_ahead_block=1 if i in config.look_ahead_layers else 0,
+                    look_backward_block=1 if i in config.look_backward_layers else 0,
+                )
+            )
+
+        self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size)  # final modulation
+        self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)
+
+    def _create_block_diff(self, hidden_states):
+        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
+        block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size  # [seq_length]
+
+        block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
+        block_j = block_indices.unsqueeze(0)  # [1, seq_length]
+        block_diff = block_j - block_i  # (n, n)
+
+        return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)
+
+    def forward(
+        self,
hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + @torch.no_grad() + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ): + noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype) + maximum_duration = quantized_code.shape[1] * self.repeats + initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + + if batch_size != 1: + raise ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + +@add_start_docstrings( + "The full Qwen2.5Omni Token2Wav model. 
Consists of a DiT model that takes speech tokens as input and predicts a mel spectrogram, and a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniToken2WavConfig"),
+)
+class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
+    config_class = Qwen2_5OmniToken2WavConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"]
+
+    def __init__(self, config: Qwen2_5OmniToken2WavConfig):
+        super().__init__(config)
+        attn_impl = config._attn_implementation
+        if config._attn_implementation == "flash_attention_2":
+            logger.warning_once(
+                "Qwen2_5OmniToken2WavModel must run inference in fp32, but flash_attention_2 only supports fp16 and bf16; "
+                "the attention implementation of Qwen2_5OmniToken2WavModel will fall back to sdpa."
+            )
+            attn_impl = "sdpa"
+        elif config._attn_implementation == "eager":
+            logger.warning_once(
+                "Qwen2_5OmniToken2WavModel does not support the eager attention implementation, falling back to sdpa"
+            )
+            attn_impl = "sdpa"
+        self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config(
+            config.dit_config, attn_implementation=attn_impl
+        )
+        self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config(
+            config.bigvgan_config, attn_implementation=attn_impl
+        )
+
+    def forward(
+        self,
+        code,
+        conditioning,
+        reference_mel,
+        num_steps=10,
+        guidance_scale=0.5,
+        sway_coefficient=-1.0,
+        **kwargs,
+    ):
+        """Generates a waveform from input code and conditioning parameters."""
+
+        mel_spectrogram = self.code2wav_dit_model.sample(
+            conditioning,
+            reference_mel,
+            code,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            sway_coefficient=sway_coefficient,
+        )
+
+        waveform = self.code2wav_bigvgan_model(mel_spectrogram)
+
+        return waveform
+
+
+############################
+#     Start Qwen2.5Omni    #
+############################
+
+
+@add_start_docstrings(
+    """
+    The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models:
+    - [`Qwen2_5OmniThinkerForConditionalGeneration`]:
+        a causal auto-regressive transformer that takes text, audio, image and video as input and predicts text tokens.
+    - [`Qwen2_5OmniTalkerForConditionalGeneration`]:
+        a causal auto-regressive transformer that takes the thinker's hidden states and response as input and predicts speech tokens.
+    - [`Qwen2_5OmniToken2WavModel`]:
+        a DiT model that takes speech tokens as input and predicts a mel spectrogram, and a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.
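+
+    A minimal usage sketch (illustrative only; `input_ids` are assumed to be produced by `Qwen2_5OmniProcessor`):
+
+    ```python
+    >>> from transformers import Qwen2_5OmniForConditionalGeneration
+
+    >>> model = Qwen2_5OmniForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+    >>> # With `return_audio=True`, both the text token ids and an audio waveform are returned.
+    >>> text_ids, waveform = model.generate(input_ids=input_ids, speaker="Chelsie", return_audio=True)
+    ```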
+    """,
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class=Qwen2_5OmniConfig),
+)
+class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin):
+    config_class = Qwen2_5OmniConfig
+    _no_split_modules = [
+        "Qwen2_5OmniTalkerForConditionalGeneration",
+        "Qwen2_5OmniToken2WavModel",
+    ]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config)
+
+        self.has_talker = config.enable_audio_output
+        self.speaker_map = {}
+        if config.enable_audio_output:
+            self.enable_talker()
+        self.post_init()
+
+    def enable_talker(self):
+        self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
+        self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config)
+        self.token2wav.float()
+        self.has_talker = True
+
+    def load_speakers(self, path):
+        for key, value in torch.load(path).items():
+            self.speaker_map[key] = value
+        logger.info("Speakers {} loaded".format(list(self.speaker_map.keys())))
+
+    def disable_talker(self):
+        if hasattr(self, "talker"):
+            del self.talker
+        if hasattr(self, "token2wav"):
+            del self.token2wav
+        self.has_talker = False
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        *model_args,
+        config=None,
+        cache_dir=None,
+        ignore_mismatched_sizes=False,
+        force_download=False,
+        local_files_only=False,
+        token=None,
+        revision="main",
+        use_safetensors=None,
+        weights_only=True,
+        **kwargs,
+    ):
+        model = super().from_pretrained(
+            pretrained_model_name_or_path,
+            *model_args,
+            config=config,
+            cache_dir=cache_dir,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            use_safetensors=use_safetensors,
+            weights_only=weights_only,
+            **kwargs,
+        )
+        spk_path = cached_file(
+            pretrained_model_name_or_path,
+            "spk_dict.pt",
+            subfolder=kwargs.pop("subfolder", None),
+            cache_dir=kwargs.pop("cache_dir", None),
+            force_download=kwargs.pop("force_download", False),
+            proxies=kwargs.pop("proxies", None),
+            resume_download=kwargs.pop("resume_download", None),
+            local_files_only=kwargs.pop("local_files_only", False),
+            token=kwargs.pop("use_auth_token", None),
+            revision=kwargs.pop("revision", None),
+        )
+        if spk_path is None:
+            raise ValueError(f"{pretrained_model_name_or_path}/spk_dict.pt does not exist")
+        model.load_speakers(spk_path)
+
+        return model
+
+    @torch.no_grad()
+    # TODO: raushan, defaults should be saved in generation config
+    def generate(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        speaker: str = "Chelsie",
+        use_audio_in_video: bool = False,
+        return_audio: Optional[bool] = None,
+        thinker_max_new_tokens: int = 1024,
+        talker_max_new_tokens: int = 4096,
+        talker_do_sample: bool = True,
+        talker_top_k: int = 40,
+        talker_top_p: float = 0.8,
+        talker_temperature: float = 0.9,
+        talker_eos_token_id: list[int] = [8292, 8294],
+        talker_repetition_penalty: float = 1.05,
+        **kwargs,
+    ):
+        r"""
+        Generate text response and audio from input.
+
+        Args:
+            input_ids (`Optional[torch.Tensor]`, *optional*):
+                Input ids, which should be obtained from the processor.
+            speaker (`str` , defaults to "Chelsie"):
+                Which speaker should be used in the audio response.
+            use_audio_in_video (`bool`, defaults to False):
+                Whether or not to use the audio track in video; should be the same as the parameter passed to `process_audio_info`.
+            return_audio (`Optional[bool]`, *optional*):
+                Whether or not to return the response in audio format.
When `return_audio=None`, this parameter is the same as `config.enable_audio_output`.
+            kwargs (*optional*):
+                - Without a prefix, they will be passed as `**kwargs` to the `generate` method of each sub-model.
+                - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be passed to the `generate` method of the
+                  thinker, talker and token2wav respectively, and take priority over the keywords without a prefix.
+        Returns:
+            When `return_audio=False`:
+                - **Text** (`torch.Tensor`): Generated text token sequence.
+            When `return_audio=True`:
+                - **Text** (`torch.Tensor`): Generated text token sequence.
+                - **Audio waveform** (`torch.Tensor`): Generated audio waveform.
+        """
+        if speaker not in self.speaker_map:
+            raise ValueError(f"{speaker} is not available, available speakers: {list(self.speaker_map.keys())}")
+        if return_audio and not self.has_talker:
+            raise ValueError(
+                "Cannot use talker when the talker module is not initialized. Use the `enable_talker` method or set `enable_audio_output` in the config to enable the talker."
+            )
+        if return_audio is None:
+            return_audio = self.has_talker
+        if input_ids.shape[0] != 1 and return_audio:
+            raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output")
+
+        shared_kwargs = {"use_audio_in_video": use_audio_in_video}
+        thinker_kwargs = {
+            "max_new_tokens": thinker_max_new_tokens,
+        }
+        talker_kwargs = {
+            "max_new_tokens": talker_max_new_tokens,
+            "do_sample": talker_do_sample,
+            "top_k": talker_top_k,
+            "top_p": talker_top_p,
+            "temperature": talker_temperature,
+            "eos_token_id": talker_eos_token_id,
+            "repetition_penalty": talker_repetition_penalty,
+        }
+        token2wav_kwargs = {}
+
+        for key, value in kwargs.items():
+            if key.startswith("thinker_"):
+                thinker_kwargs[key[len("thinker_") :]] = value
+            elif key.startswith("talker_"):
+                talker_kwargs[key[len("talker_") :]] = value
+            elif key.startswith("token2wav_"):
+                token2wav_kwargs[key[len("token2wav_") :]] = value
+            # Process special input values
+            elif key == "feature_attention_mask":
+                thinker_kwargs[key] = value
+                talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1)
+            elif key == "input_features" or key == "attention_mask":
+                thinker_kwargs[key] = value
+            # Put other keys into shared kwargs
+            else:
+                shared_kwargs[key] = value
+
+        # Merge kwargs
+        for key, value in shared_kwargs.items():
+            if key not in thinker_kwargs:
+                thinker_kwargs[key] = value
+            if key not in talker_kwargs:
+                talker_kwargs[key] = value
+            if key not in token2wav_kwargs:
+                token2wav_kwargs[key] = value
+        speaker_params = self.speaker_map[speaker]
+
+        # 1. Generate from thinker module
+        generate_audio = return_audio and self.has_talker
+        if generate_audio:
+            thinker_kwargs["output_hidden_states"] = True
+            thinker_kwargs["return_dict_in_generate"] = True
+
+        thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs)
+
+        if not generate_audio:
+            return thinker_result
+
+        # 2. 
Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) + if thinker_kwargs.get("input_features", None) is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device) + thinker_token_embeds = [ + token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(self.talker.device), + torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device), + torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device), + torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device) + ).to(self.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device) + ).to(self.talker.device) + + thinker_reply_part 
= torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + + talker_attention_mask = None + if "attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(self.talker.device) + + talker_result = self.talker.generate( + input_ids=talker_input_ids, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, + ) + talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + + # 3. Generate wavs from code + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + wav = self.token2wav( + talker_generate_codes.to(self.token2wav.device), + conditioning=speaker_params["cond"].to(self.token2wav.device).float(), + reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + +__all__ = [ + "Qwen2_5OmniForConditionalGeneration", + "Qwen2_5OmniThinkerTextModel", + "Qwen2_5OmniThinkerForConditionalGeneration", + "Qwen2_5OmniTalkerModel", + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavDiTModel", + "Qwen2_5OmniToken2WavBigVGANModel", + "Qwen2_5OmniToken2WavModel", + "Qwen2_5OmniPreTrainedModel", + "Qwen2_5OmniPreTrainedModelForConditionalGeneration", +] diff --git a/docs/transformers/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/docs/transformers/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..294a72c6b6cc06b3d594cb04c67944677e6cba54 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -0,0 +1,4382 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen2.5Omni model (Audio, Image, Video).""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import Parameter + +from transformers.models.llama.modeling_llama import rotate_half +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLMLP, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + Qwen2_5_VLVisionBlock, + Qwen2RMSNorm, +) +from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig +from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding + +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.hub import cached_file + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + + +logger = logging.get_logger(__name__) + + +class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerVision`]. It is used to instantiate a + Qwen2.5-VL vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2.5-VL + architecture. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + depth (`int`, *optional*, defaults to 32): + Number of layers (depth) in the model. + hidden_size (`int`, *optional*, defaults to 3584): + The size of the hidden layers. + hidden_act (`str`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function used in the model. Supported options include `"quick_gelu"` and others as applicable. + mlp_ratio (`float`, *optional*, defaults to 4): + The ratio used to determine the size of the MLP (Multi-Layer Perceptron) hidden layer. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + patch_size (`int`, *optional*, defaults to 14): + The size of the patches extracted from the input. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. 
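+        intermediate_size (`int`, *optional*, defaults to 3420):
+            Dimension of the MLP representations.
+        window_size (`int`, *optional*, defaults to 112):
+            The size of the attention window used by the windowed-attention blocks.
+        out_hidden_size (`int`, *optional*, defaults to 3584):
+            The dimension of the encoder output.
+        fullatt_block_indexes (`List[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
+            The indexes of the blocks that use full attention instead of windowed attention.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.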
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniVisionEncoderConfig, Qwen2_5OmniVisionEncoder
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoderConfig
+    >>> configuration = Qwen2_5OmniVisionEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniVisionEncoder (with random weights)
+    >>> model = Qwen2_5OmniVisionEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_vision_encoder"
+
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=3584,
+        hidden_act="silu",
+        intermediate_size=3420,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        window_size=112,
+        out_hidden_size=3584,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(
+            depth,
+            hidden_size,
+            hidden_act,
+            intermediate_size,
+            num_heads,
+            in_channels,
+            patch_size,
+            spatial_merge_size,
+            temporal_patch_size,
+            window_size,
+            out_hidden_size,
+            fullatt_block_indexes,
+            initializer_range=initializer_range,
+            **kwargs,
+        )
+        del self.tokens_per_second
+
+
+class Qwen2_5OmniAudioEncoderConfig(Qwen2AudioEncoderConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniAudioEncoder`]. It is used to instantiate a
+    Qwen2.5-Omni-Thinker audio encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
+    architecture.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of mel features used per input feature. Should correspond to the value used in the
+            `Qwen2_5OmniProcessor` class.
+        encoder_layers (`int`, *optional*, defaults to 32):
+            Number of encoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 20):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 5120):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+        d_model (`int`, *optional*, defaults to 1280):
+            Dimensionality of the layers.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        max_source_positions (`int`, *optional*, defaults to 1500):
+            The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
+        n_window (`int`, *optional*, defaults to 100):
+            The chunk size used for the convolution and flash attention in the AudioEncoder.
+        output_dim (`int`, *optional*, defaults to 3584):
+            The output dimension of the AudioEncoder.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder
+
+    >>> # Initializing a Qwen2_5OmniAudioEncoderConfig
+    >>> configuration = Qwen2_5OmniAudioEncoderConfig()
+
+    >>> # Initializing a Qwen2_5OmniAudioEncoder (with random weights)
+    >>> model = Qwen2_5OmniAudioEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_omni_audio_encoder"
+
+    def __init__(
+        self,
+        num_mel_bins=128,
+        encoder_layers=32,
+        encoder_attention_heads=20,
+        encoder_ffn_dim=5120,
+        d_model=1280,
+        dropout=0,
+        attention_dropout=0,
+        activation_function="gelu",
+        activation_dropout=0,
+        scale_embedding=False,
+        initializer_range=0.02,
+        max_source_positions=1500,
+        n_window=100,
+        output_dim=3584,
+        **kwargs,
+    ):
+        super().__init__(
+            num_mel_bins,
+            encoder_layers,
+            encoder_attention_heads,
+            encoder_ffn_dim,
+            d_model,
+            dropout,
+            attention_dropout,
+            activation_function,
+            activation_dropout,
+            scale_embedding,
+            initializer_range,
+            max_source_positions,
+            **kwargs,
+        )
+        self.n_window = n_window
+        self.output_dim = output_dim
+        del self.encoder_layerdrop
+
+
+class Qwen2_5OmniTextConfig(Qwen2Config):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate a
+    Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker.
+
+    e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 152064):
+            Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen2VLModel`].
+        hidden_size (`int`, *optional*, defaults to 3584):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 18944):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 28):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 28):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `4`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 32768): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + Example: + + ```python + >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniVisionEncoder config + >>> vision_config = Qwen2_5OmniVisionEncoderConfig() + + >>> # Initializing a Qwen2.5OmniThinker configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config) + + >>> # Initializing a model from the Qwen-Omni style configuration + >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_text" + + def __init__( + self, + vocab_size=152064, + hidden_size=3584, + intermediate_size=18944, + num_hidden_layers=28, + num_attention_heads=28, + num_key_value_heads=4, + rope_theta=1000000.0, + sliding_window=32768, + **super_kwargs, + ): + super().__init__(**super_kwargs) + if self.rope_scaling is None: + self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} + + +class Qwen2_5OmniThinkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniThinkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Thinker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_config (`dict`, *optional*): + The config dictionary of the audio backbone. + vision_config (`dict`, *optional*): + The config dictionary of the vision backbone. + text_config (`dict`, *optional*): + The config dictionary of the text backbone. + audio_token_index (`int`, *optional*, defaults to 151646): + The audio token index to encode the audio prompt. + image_token_index (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 151656): + The video token index to encode the video prompt. + position_id_per_seconds (`int`, *optional*, defaults to 25): + The increment of position id per second. + seconds_per_chunk (`int`, *optional*, defaults to 2): + The duration in seconds of the chunk of audio and video data. + audio_start_token_id (`int`, *optional*, defaults to 151647): + The audio start token index to encode the audio prompt. + audio_end_token_id (`int`, *optional*, defaults to 151648): + The audio end token index to encode the audio prompt. + user_token_id (`int, *optional*, defaults to 872): + The user token index to encode the user token. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + + ```python + >>> from transformers import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2_5OmniVisionEncoder config + >>> vision_config = Qwen2_5OmniVisionEncoderConfig() + + >>> # Initializing a Qwen2_5OmniTextConfig config + >>> text_config = Qwen2_5OmniTextConfig() + + >>> # Initializing a Qwen2.5OmniThinker configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, vision_config, text_config) + + >>> # Initializing a model from the Qwen-Omni style configuration + >>> model = Qwen2_5OmniThinkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_thinker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } + sub_configs = { + "audio_config": Qwen2_5OmniAudioEncoderConfig, + "vision_config": Qwen2_5OmniVisionEncoderConfig, + "text_config": Qwen2_5OmniTextConfig, + } + + def __init__( + self, + audio_config=None, + vision_config=None, + text_config=None, + audio_token_index=151646, + image_token_index=151655, + video_token_index=151656, + position_id_per_seconds=25, + seconds_per_chunk=2, + audio_start_token_id=151647, + audio_end_token_id=151648, + user_token_id=872, + initializer_range=0.02, + **kwargs, + ): + self.audio_token_index = audio_token_index + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.user_token_id = user_token_id + self.position_id_per_seconds = position_id_per_seconds + self.seconds_per_chunk = seconds_per_chunk + self.audio_start_token_id = audio_start_token_id + self.audio_end_token_id = audio_end_token_id + self.initializer_range = initializer_range + + if isinstance(vision_config, dict): + vision_config = Qwen2_5OmniVisionEncoderConfig(**vision_config) + elif vision_config is None: + vision_config = Qwen2_5OmniVisionEncoderConfig() + self.vision_config = vision_config + + if isinstance(audio_config, dict): + audio_config = Qwen2_5OmniAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen2_5OmniAudioEncoderConfig() + self.audio_config = audio_config + + if isinstance(text_config, dict): + text_config = Qwen2_5OmniTextConfig(**text_config) + elif text_config is None: + text_config = Qwen2_5OmniTextConfig() + self.text_config = text_config + + super().__init__(**kwargs) + + +class Qwen2_5OmniTalkerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniTalkerForConditionalGeneration`]. It is used to instantiate an + Qwen2.5-Omni-Talker model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Qwen2.5-Omni-Thinker. + + e.g. [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + audio_token_index (`int`, *optional*, defaults to 151646): + The audio token index to encode the audio prompt. 
+ image_token_index (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 151656): + The video token index to encode the video prompt. + vocab_size (`int`, *optional*, defaults to 8448): + Vocabulary size of the QwenOmni model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + tts_text_start_token_id (`int`, *optional*, defaults to 151860): + The tts text start token index to encode the start of tts text. + tts_text_end_token_id (`int`, *optional*, defaults to 151861): + The tts text end token index to encode the end of tts text. + tts_text_pad_token_id (`int`, *optional*, defaults to 151859): + The tts text pad token index to encode the pad of tts text. + tts_codec_start_token_id (`int`, *optional*, defaults to 8293): + The tts codec start token index to encode the start of tts codec. + tts_codec_end_token_id (`int`, *optional*, defaults to 8294): + The tts codec end token index to encode the end of tts codec. + tts_codec_pad_token_id (`int`, *optional*, defaults to 8292): + The tts codec pad token index to encode the pad of tts codec. + tts_codec_mask_token_id (`int`, *optional*, defaults to 8296): + The tts codec mask token index to encode the mask of tts codec. + vision_start_token_id (`int`, *optional*, defaults to 151652): + The tts vision start token index to encode the start of vision. + vision_end_token_id (`int`, *optional*, defaults to 151653): + The tts vision end token index to encode the end of vision. + embedding_size (`int`, *optional*, defaults to 3584): + Dimension of the embedding representations. + hidden_size (`int`, *optional*, defaults to 3584): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18944): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 28): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + head_dim (`int`, *optional*, defaults to 128): + The dimension of each attention head. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 32768): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + position_id_per_seconds (`int`, *optional*, defaults to 25): + The increment of position id per second. + seconds_per_chunk (`int`, *optional*, defaults to 2): + The duration in seconds of the chunk of audio and video data. + audio_start_token_id (`int`, *optional*, defaults to 151647): + The audio start token index to encode the audio prompt. 
+ audio_end_token_id (`int`, *optional*, defaults to 151648): + The audio end token index to encode the audio prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + + Example: + + ```python + >>> from transformers import Qwen2_5OmniTalkerForConditionalGeneration, Qwen2_5OmniThinkerConfig, Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniVisionEncoderConfig + + >>> # Initializing a Qwen2_5OmniAudioEncoder config + >>> audio_config = Qwen2_5OmniAudioEncoderConfig() + + >>> # Initializing a Qwen2 config + >>> text_config = Qwen2Config() + + >>> # Initializing a Qwen2_5Omni configuration + >>> configuration = Qwen2_5OmniThinkerConfig(audio_config, text_config) + + >>> # Initializing a model from the qwen2-audio style configuration + >>> model = Qwen2_5OmniTalkerForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_omni_talker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } + + def __init__( + self, + audio_token_index=151646, + image_token_index=151655, + video_token_index=151656, + vocab_size=8448, + tts_text_start_token_id=151860, + tts_text_end_token_id=151861, + tts_text_pad_token_id=151859, + tts_codec_start_token_id=8293, + tts_codec_end_token_id=8294, + tts_codec_pad_token_id=8292, + tts_codec_mask_token_id=8296, + vision_start_token_id=151652, + vision_end_token_id=151653, + embedding_size=3584, + hidden_size=3584, + intermediate_size=18944, + num_hidden_layers=28, + num_attention_heads=28, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + rms_norm_eps=1e-06, + head_dim=128, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=32768, + max_window_layers=28, + attention_dropout=0.0, + rope_scaling=None, + position_id_per_seconds=25, + seconds_per_chunk=2, + audio_start_token_id=151647, + audio_end_token_id=151648, + initializer_range=0.02, + spatial_merge_size=2, + **kwargs, + ): + self.audio_token_index = audio_token_index + self.image_token_index = image_token_index + self.video_token_index = video_token_index + + self.tts_text_start_token_id = tts_text_start_token_id + self.tts_text_end_token_id = tts_text_end_token_id + self.tts_text_pad_token_id = tts_text_pad_token_id + self.tts_codec_start_token_id = tts_codec_start_token_id + self.tts_codec_end_token_id = tts_codec_end_token_id + self.tts_codec_pad_token_id = tts_codec_pad_token_id + + self.tts_codec_mask_token_id = tts_codec_mask_token_id + + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + self.vocab_size = vocab_size + self.head_dim = head_dim + self.embedding_size = embedding_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = 
+        self.hidden_act = hidden_act
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+        self.position_id_per_seconds = position_id_per_seconds
+        self.seconds_per_chunk = seconds_per_chunk
+        self.audio_start_token_id = audio_start_token_id
+        self.audio_end_token_id = audio_end_token_id
+
+        self.initializer_range = initializer_range
+        self.spatial_merge_size = spatial_merge_size
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen2_5OmniDiTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavDiT used in the Qwen2.5-Omni-Token2Wav model.
+    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            The dimension of the model.
+        num_hidden_layers (`int`, *optional*, defaults to 22):
+            The number of transformer blocks in the DiT model.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            The number of attention heads in each transformer block.
+        ff_mult (`int`, *optional*, defaults to 2):
+            The multiplier for the feedforward layer in each transformer block.
+        emb_dim (`int`, *optional*, defaults to 512):
+            The dimension of the embedding layer.
+        head_dim (`int`, *optional*, defaults to 64):
+            The dimension of each attention head.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the rotary position embeddings.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length the model might ever be used with.
+        block_size (`int`, *optional*, defaults to 24):
+            The block size used for block-wise attention.
+        look_ahead_layers (`List[int]`, *optional*, defaults to `[10]`):
+            The indices of the layers that use look-ahead attention.
+        look_backward_layers (`List[int]`, *optional*, defaults to `[0, 20]`):
+            The indices of the layers that use look-backward attention.
+        repeats (`int`, *optional*, defaults to 2):
+            The number of times the codec embeddings are repeated.
+        num_embeds (`int`, *optional*, defaults to 8193):
+            The number of unique embeddings in the codec.
+        mel_dim (`int`, *optional*, defaults to 80):
+            The dimension of the mel-spectrogram.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout rate for the transformer blocks.
+        enc_emb_dim (`int`, *optional*, defaults to 192):
+            The dimension of the pre-trained speaker embedding.
+        enc_dim (`int`, *optional*, defaults to 128):
+            The dimension of the encoder output.
+        enc_channels (`List[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
+            A list of output channels for each TDNN/SERes2Net layer in the encoder.
+        enc_kernel_sizes (`List[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
+            A list of kernel sizes for each layer in the encoder.
+        enc_dilations (`List[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
+            A list of dilations for each layer in the encoder.
+        enc_attention_channels (`int`, *optional*, defaults to 64):
+            The number of attention channels in the SqueezeExcitationBlock.
+        enc_res2net_scale (`int`, *optional*, defaults to 2):
+            The scale of the Res2Net block in the encoder.
+        enc_se_channels (`int`, *optional*, defaults to 64):
+            The number of output channels after squeeze in the SqueezeExcitationBlock.
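+
+    Example (a minimal sketch; it assumes `Qwen2_5OmniDiTConfig` is exported from `transformers` like the other
+    configuration classes in this file):
+
+    ```python
+    >>> from transformers import Qwen2_5OmniDiTConfig
+
+    >>> # Initializing a DiT configuration, overriding a couple of the defaults listed above
+    >>> configuration = Qwen2_5OmniDiTConfig(hidden_size=1024, num_hidden_layers=22)
+    ```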
+ """ + + model_type = "qwen2_5_omni_dit" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=22, + num_attention_heads=16, + ff_mult=2, + emb_dim=512, + head_dim=64, + rope_theta=10000.0, + max_position_embeddings=32768, + block_size=24, + look_ahead_layers=[10], + look_backward_layers=[0, 20], + repeats=2, + num_embeds=8193, + mel_dim=80, + dropout=0.1, + enc_emb_dim=192, + enc_dim=128, + enc_channels=[256, 256, 256, 256, 768], + enc_kernel_sizes=[5, 3, 3, 3, 1], + enc_dilations=[1, 2, 3, 4, 1], + enc_attention_channels=64, + enc_res2net_scale=2, + enc_se_channels=64, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.ff_mult = ff_mult + self.emb_dim = emb_dim + self.head_dim = head_dim + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.block_size = block_size + self.look_ahead_layers = look_ahead_layers + self.look_backward_layers = look_backward_layers + self.repeats = repeats + self.num_embeds = num_embeds + self.mel_dim = mel_dim + self.dropout = dropout + self.enc_emb_dim = enc_emb_dim + self.enc_dim = enc_dim + self.enc_channels = enc_channels + self.enc_kernel_sizes = enc_kernel_sizes + self.enc_dilations = enc_dilations + self.enc_attention_channels = enc_attention_channels + self.enc_res2net_scale = enc_res2net_scale + self.enc_se_channels = enc_se_channels + super().__init__(**kwargs) + + +class Qwen2_5OmniBigVGANConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of the Qwen2_5OmniToken2WavBigVGAN module used in the Qwen2.5-Omni-Token2Wav model. + It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms. + + Args: + mel_dim (`int`, *optional*, defaults to 80): + The dimension of the mel-spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 1536): + The number of channels in the initial upsampling layer. + resblock_kernel_sizes (`List[int]`, *optional*, defaults to `[3, 7, 11]`): + A list of kernel sizes for each residual block. + resblock_dilation_sizes (`List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A list of dilation sizes for each residual block. + upsample_rates (`List[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`): + A list of upsampling rates for each upsampling layer. + upsample_kernel_sizes (`List[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`): + A list of kernel sizes for each upsampling layer. + """ + + model_type = "qwen2_5_omni_bigvgan" + + def __init__( + self, + mel_dim=80, + upsample_initial_channel=1536, + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[5, 3, 2, 2, 2, 2], + upsample_kernel_sizes=[11, 7, 4, 4, 4, 4], + **kwargs, + ): + self.mel_dim = mel_dim + self.upsample_initial_channel = upsample_initial_channel + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + super().__init__(**kwargs) + + +class Qwen2_5OmniToken2WavConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5OmniToken2WavModel`]. + It is used to instantiate the Qwen2.5-Omni-Token2Wav model which combines a Diffusion Transformer (DiT) for mel-spectrogram generation with a BigVGAN model for waveform synthesis. 
The configuration contains sub-configurations for both components.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        dit_config (`dict`, *optional*):
+            Configuration dict for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms,
+            used to build a [`Qwen2_5OmniDiTConfig`].
+        bigvgan_config (`dict`, *optional*):
+            Configuration dict for the BigVGAN module responsible for converting mel-spectrograms to waveforms,
+            used to build a [`Qwen2_5OmniBigVGANConfig`].
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2_5OmniToken2WavModel, Qwen2_5OmniToken2WavConfig
+
+    >>> # Initialize the DiT sub-configuration as a dict of overrides
+    >>> dit_config = {
+    ...     "hidden_size": 1024,
+    ...     "num_hidden_layers": 22,
+    ...     "num_attention_heads": 16,
+    ...     "ff_mult": 2,
+    ... }
+
+    >>> # Initialize the BigVGAN sub-configuration as a dict of overrides
+    >>> bigvgan_config = {
+    ...     "mel_dim": 80,
+    ...     "upsample_rates": [5, 3, 2, 2, 2, 2],
+    ... }
+
+    >>> # Initialize the main configuration
+    >>> configuration = Qwen2_5OmniToken2WavConfig(dit_config, bigvgan_config)
+
+    >>> # Initialize the model with the configuration
+    >>> model = Qwen2_5OmniToken2WavModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni_token2wav"
+    sub_configs = {
+        "dit_config": Qwen2_5OmniDiTConfig,
+        "bigvgan_config": Qwen2_5OmniBigVGANConfig,
+    }
+
+    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
+        if dit_config is None:
+            dit_config = {}
+        if bigvgan_config is None:
+            bigvgan_config = {}
+        self.dit_config = Qwen2_5OmniDiTConfig(**dit_config)
+        self.bigvgan_config = Qwen2_5OmniBigVGANConfig(**bigvgan_config)
+        super().__init__(**kwargs)
+
+
+class Qwen2_5OmniConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`Qwen2_5OmniForConditionalGeneration`]. It is used to instantiate a Qwen2.5Omni
+    model according to the specified sub-model configurations, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    [Qwen/Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
+        talker_config (`dict`, *optional*): Configuration of the underlying talker sub-model.
+        token2wav_config (`dict`, *optional*): Configuration of the underlying codec sub-model.
+        enable_audio_output (`bool`, *optional*, defaults to `True`): Whether to enable audio output and load the talker and token2wav modules.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     Qwen2_5OmniThinkerConfig,
+    ...     Qwen2_5OmniTalkerConfig,
+    ...     Qwen2_5OmniToken2WavConfig,
+    ...     Qwen2_5OmniForConditionalGeneration,
+    ...     Qwen2_5OmniConfig,
+    ... )
+
+    >>> # Initializing the sub-module configurations.
+    >>> thinker_config = Qwen2_5OmniThinkerConfig()
+    >>> talker_config = Qwen2_5OmniTalkerConfig()
+    >>> token2wav_config = Qwen2_5OmniToken2WavConfig()
+
+    >>> # Initializing a Qwen2.5-Omni configuration from the sub-model configurations
+    >>> configuration = Qwen2_5OmniConfig.from_sub_model_configs(
+    ...     thinker_config, talker_config, token2wav_config
+    ...
) 
+
+    >>> # Initializing a model (with random weights)
+    >>> model = Qwen2_5OmniForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen2_5_omni"
+    sub_configs = {
+        "thinker_config": Qwen2_5OmniThinkerConfig,
+        "talker_config": Qwen2_5OmniTalkerConfig,
+        "token2wav_config": Qwen2_5OmniToken2WavConfig,
+    }
+
+    def __init__(
+        self,
+        thinker_config=None,
+        talker_config=None,
+        token2wav_config=None,
+        enable_audio_output: bool = True,
+        **kwargs,
+    ):
+        if thinker_config is None:
+            thinker_config = {}
+            logger.info("thinker_config is None. Initializing thinker model with default values")
+
+        if talker_config is None:
+            talker_config = {}
+            logger.info("talker_config is None. Initializing talker model with default values")
+
+        if token2wav_config is None:
+            token2wav_config = {}
+            logger.info("token2wav_config is None. Initializing token2wav model with default values")
+
+        self.thinker_config = Qwen2_5OmniThinkerConfig(**thinker_config)
+        self.talker_config = Qwen2_5OmniTalkerConfig(**talker_config)
+        self.token2wav_config = Qwen2_5OmniToken2WavConfig(**token2wav_config)
+        self.enable_audio_output = enable_audio_output
+
+        super().__init__(**kwargs)
+
+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
+        """
+        # Overridden for deeply nested configs like Qwen2-Omni. We don't have any omni model
+        # except for Qwen yet. This has to be generalized if more deeply nested configs are
+        # added. NOTE: this method is currently only used by vLLM
+        return self.thinker_config.get_text_config()
+
+
+class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
+    config_class = Qwen2_5OmniConfig
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
+        # inference and fine-tuning - so the proper init weights code has been removed
+        std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02
+
+        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
+        elif isinstance(module, Qwen2RMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel):
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        self,
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        min_dtype: float,
+        cache_position: torch.Tensor,
+        batch_size: int,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
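+        The returned mask holds `min_dtype` at masked positions, so it can be added directly to the attention
+        scores before the softmax.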
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            min_dtype (`float`):
+                The minimum value representable with the dtype `dtype`.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+    def get_llm_pos_ids_for_vision(
+        self,
+        start_idx: int,
+        vision_idx: int,
+        spatial_merge_size: int,
+        t_index: List[int],
+        grid_hs: List[int],
+        grid_ws: List[int],
+    ):
+        llm_pos_ids_list = []
+        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
+        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
+        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
+        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
+        t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
+        _llm_pos_ids = torch.stack([t_index, h_index, w_index])
+        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
+        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
+        return llm_pos_ids
+
+    def get_chunked_index(
+        self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
+    ) -> list[tuple[int, int]]:
+        """
+        Splits a token index list into chunks based on token value ranges.
+
+        Given a list of token indices, returns a list of (start, end) index tuples representing
+        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+        - the first chunk contains token values < 1000,
+        - the second chunk contains values >= 1000 and < 2000, and so on.
+
+        Args:
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
+            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+            remove_index (`int`): An index id to subtract from `token_indices` before chunking.
+
+        Returns:
+            `List[Tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+                and end (exclusive) indices of a chunk in `token_indices`.
+        """
+
+        def _iter():
+            i, start_idx = 0, 0  # skip bos token
+            current_chunk = 1
+            while i < len(token_indices):  # skip eos token
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
+                    yield (start_idx, i)
+                    start_idx = i
+                    current_chunk += 1
+                i += 1
+            yield (start_idx, len(token_indices))
+
+        return list(_iter())
+
+    def get_rope_index(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+        audio_seqlens: Optional[torch.LongTensor] = None,
+        second_per_grids: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+        Explanation:
+            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+            Examples:
+                input_ids: [T T T T T], here T is for text.
+                temporal position_ids: [0, 1, 2, 3, 4]
+                height position_ids: [0, 1, 2, 3, 4]
+                width position_ids: [0, 1, 2, 3, 4]
+
+            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+            and 1D rotary position embedding for text part.
+            Examples:
+                Temporal (Time): 3 patches, representing different segments of the video in time.
+                Height: 2 patches, dividing each frame vertically.
+                Width: 2 patches, dividing each frame horizontally.
+                We also have some important parameters:
+                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+                text temporal position_ids: [101, 102, 103, 104, 105]
+                text height position_ids: [101, 102, 103, 104, 105]
+                text width position_ids: [101, 102, 103, 104, 105]
+                Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + use_audio_in_video (`bool`, *optional*): + If set to `True`, use the audio in video. + audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id + vision_start_token_id = self.config.vision_start_token_id + audio_start_token_id = self.config.audio_start_token_id + position_id_per_seconds = self.config.position_id_per_seconds + seconds_per_chunk = self.config.seconds_per_chunk + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = ( + image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio = input_tokens.index(audio_token_id, st) + else: + ed_audio = len(input_tokens) + 1 + min_ed = min(ed_image, ed_video, ed_audio) + if min_ed == ed_audio: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + + elif min_ed == ed_image: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + + elif min_ed == ed_video and not use_audio_in_video: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + + elif min_ed == ed_video and use_audio_in_video: + text_len = min_ed - st - 2 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + 
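+                        # Audio-in-video case: two BOS entries are appended with the same start index (one
+                        # for the audio stream, one for the vision stream), the audio and video position ids
+                        # are then interleaved chunk by chunk via get_chunked_index, and two matching EOS
+                        # entries close the segment.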
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + video_llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + + t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + sub_len = 0 + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None + audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None + if video_chunk_index is not None: + sub_len += video_chunk_index[1] - video_chunk_index[0] + + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]] + ) + if audio_chunk_index is not None: + sub_len += audio_chunk_index[1] - audio_chunk_index[0] + + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + + return position_ids, mrope_position_deltas + else: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas + + +############################ +# Start Thinker # +############################ + + +@dataclass +class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5OmniAudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Qwen2_5OmniAudioEncoderConfig, + ): + super().__init__() + self.embed_dim = config.d_model + self.num_heads = config.encoder_attention_heads + self.dropout = config.attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.config = config + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = False + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + query_states = query_states.transpose(0, 1) + key_states = key_states.transpose(0, 1) + value_states = value_states.transpose(0, 1) + attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) + + attention_mask = torch.full( + [1, seq_length, key_states.shape[1]], + torch.finfo(query_states.dtype).min, + device=query_states.device, + dtype=query_states.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) + + attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention): + """ + Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + seq_length, all_dim = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.reshape(seq_length, self.num_heads, -1) + + key_states = self.k_proj(hidden_states) + key_states = key_states.reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states) + value_states = value_states.reshape(seq_length, self.num_heads, -1) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func( + query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0 + ) + attn_output = attn_output.reshape(seq_length, all_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention): + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + attention_mask = torch.zeros( + [1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + query_states = query_states.transpose(0, 1) + key_states = key_states.transpose(0, 1) + value_states = value_states.transpose(0, 1) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + ) + attn_output = attn_output.transpose(0, 1) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+        attn_output = attn_output.reshape(seq_length, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output
+
+
+QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = {
+    "eager": Qwen2_5OmniAudioAttention,
+    "flash_attention_2": Qwen2_5OmniAudioFlashAttention2,
+    "sdpa": Qwen2_5OmniAudioSdpaAttention,
+}
+
+
+class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
+    def __init__(self, config: Qwen2_5OmniAudioEncoderConfig):
+        super().__init__(config)
+        self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> Tuple[torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            cu_seqlens=cu_seqlens,
+        )
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        return outputs
+
+
+class SinusoidsPositionEmbedding(nn.Module):
+    def __init__(self, length, channels, max_timescale=10000):
+        super().__init__()
+        if channels % 2 != 0:
+            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)).float()
+        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        self.register_buffer(
+            "positional_embedding",
+            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+            persistent=False,
+        )
+
+    def forward(self, seqlen: int):
+        return self.positional_embedding[:seqlen, :]
+
+
+QWEN_OMNI_AUDIO_INPUTS = r"""
+    Args:
+        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
+            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
+            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+            and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
+            The length of each mel feature sequence.
+        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
+            The length of each feature sequence after the convolutional front-end.
+    """
+
+
+class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`Qwen2_5OmniAudioEncoderLayer`].
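+
+    The encoder processes the incoming mel features in fixed-size windows of `n_window * 2` frames and
+    restricts self-attention to each window through the `cu_seqlens` boundaries passed to every layer.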
+ + Args: + config: Qwen2_5OmniAudioEncoderConfig + """ + + config_class = Qwen2_5OmniAudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"] + _supports_sdpa = True + + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.audio_bos_eos_token = nn.Embedding(2, config.output_dim) + self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.ln_post = nn.LayerNorm(config.d_model) + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(config.d_model, config.output_dim) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + @add_start_docstrings_to_model_forward(QWEN_OMNI_AUDIO_INPUTS) + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + ): + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + + chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( + chunk_list, chunk_lengths, padding_value=0, padding_side="right" + ) + padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) + + padded_embed = padded_embed + self.positional_embedding.positional_embedding[ + : padded_embed.shape[1], : + ].unsqueeze(0).to(padded_embed.dtype) + hidden_states = padded_embed[padded_mask_after_cnn] + cu_seqlens = torch.cat( + ( + torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0), + ) + ).to(torch.int32) + + for idx, encoder_layer in enumerate(self.layers): + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + cu_seqlens, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states) + each_audio_states = 
self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + token_audio = torch.cat(token_audio_list, dim=0) + return BaseModelOutput(last_hidden_state=token_audio) + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. + """ + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5OmniVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = 
torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + tensor_ = tensor.float() + cos = freqs.cos() # .type_as(tensor_) + sin = freqs.sin() # .type_as(tensor_) + output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) + return output + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5OmniVisionAttention, + "flash_attention_2": Qwen2_5OmniVisionFlashAttention2, + "sdpa": Qwen2_5OmniVisionSdpaAttention, +} + + +class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: + super().__init__(config, config._attn_implementation) + self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + + def forward(self, hidden_states, 
cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): + config_class = Qwen2_5OmniVisionEncoderConfig + _no_split_modules = ["Qwen2_5OmniVisionBlock"] + + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Modification here + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5OmniRotaryEmbedding(Qwen2VLRotaryEmbedding): + def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None): + super().__init__(config, device) + + +# It's same as `Qwen2_5_VLAttention`, but talker model's hidden_size isn't divisible by num_heads. +# Removes the value error as a workaround. 
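+# head_dim is read from the config in __init__ (falling back to hidden_size // num_heads), so no
+# divisibility check is performed.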
+class Qwen2_5OmniAttention(Qwen2_5_VLAttention, nn.Module):
+    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None):
+        nn.Module.__init__(self)
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+        self.rope_scaling = config.rope_scaling
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+        self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
+
+
+class Qwen2MLP(Qwen2_5_VLMLP):
+    pass
+
+
+QWEN2_5OMNI_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`{config_class}`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Qwen2.5OmniThinker Model outputting raw hidden-states without any specific head on top.",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTextConfig"),
+)
+class Qwen2_5OmniThinkerTextModel(Qwen2_5_VLModel):
+    config_class = Qwen2_5OmniTextConfig
+    _no_split_modules = ["Qwen2_5OmniDecoderLayer"]
+
+    def __init__(self, config: Qwen2_5OmniTextConfig):
+        super().__init__(config)
+
+
+QWEN2_5OMNITHINKER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
+            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
+            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
+            the soundfile library (`pip install soundfile`).
To prepare the array into `input_features`, the
+            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
+            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`].
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`NewTaskModelProcessor`] uses
+            [`SiglipImageProcessor`] for processing images).
+        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*):
+            The tensors corresponding to the input videos. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([`NewTaskModelProcessor`] uses
+            [`SiglipImageProcessor`] for processing videos).
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+            The length of feature shape of each audio in LLM.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+        past_key_values (`list(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """The Qwen2.5OmniThinker model which consists of a audio backbone and a language model.""", + QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniThinkerConfig"), +) +class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): + config_class = Qwen2_5OmniThinkerConfig + base_model_prefix = "thinker" + _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"] + + def __init__(self, config: Qwen2_5OmniThinkerConfig): + super().__init__(config) + self.audio_tower = Qwen2_5OmniAudioEncoder._from_config( + config.audio_config, attn_implementation=config._attn_implementation + ) + + self.visual = Qwen2_5OmniVisionEncoder._from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + + self.vocab_size = config.text_config.vocab_size + self.model = Qwen2_5OmniThinkerTextModel._from_config( + config.text_config, attn_implementation=config._attn_implementation + ) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.rope_deltas = None + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @add_start_docstrings_to_model_forward(QWEN2_5OMNITHINKER_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Qwen2_5OmniThinkerCausalLMOutputWithPast, config_class="Qwen2_5OmniThinkerConfig" + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + feature_attention_mask: 
Optional[torch.Tensor] = None,
+        audio_feature_lengths: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_audio_in_video: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        video_second_per_grid: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from qwen_vl_utils import process_vision_info
+        >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
+
+        >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+        >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+        >>> conversations = [
+        ...     {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'},
+        ...     {"role": "user", "content": [
+        ...         {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+        ...         {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
+        ...     ]},
+        ... ]
+
+        >>> text = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
+        >>> audios = [librosa.load(BytesIO(urlopen(conversations[1]['content'][1]['audio_url']).read()), sr=processor.feature_extractor.sampling_rate)[0]]
+        >>> images, videos = process_vision_info(conversations)
+        >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
+
+        >>> # Generate
+        >>> inputs['use_audio_in_video'] = True  # or False, depending on whether the video's audio track should be used
+        >>> generation = thinker.generate(**inputs, max_new_tokens=2048)
+        >>> generate_ids = generation[:, inputs.input_ids.size(1):]
+
+        >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+            input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
+        else:
+            audio_feature_lengths = None
+        if attention_mask is not None and position_ids is None:
+            if (
+                cache_position is None
+                or (cache_position is not None and cache_position[0] == 0)
+                or 
self.rope_deltas is None + ): + delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + rope_deltas = rope_deltas - delta0 + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if inputs_embeds is None: + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text , audios , image and video + if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage + if input_features is not None: + audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + feature_lens = ( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + if audio_features.shape[0] != sum(audio_output_lengths.tolist()): + raise ValueError("length of audio_features should match audio_output_lengths") + audio_mask = ( + (input_ids == self.config.audio_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size + ) + + if not return_dict: + output = (logits,) + outputs + return 
(loss,) + output if loss is not None else output + + return Qwen2_5OmniThinkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_features=None, + feature_attention_mask=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + +############################ +# Start Talker # +############################ + + +@dataclass +class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
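+        thinker_reply_part (`torch.FloatTensor`, *optional*):
+            Remaining hidden states of the thinker's text reply; one slice is merged into the codec embeddings at
+            each talker generation step.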
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + thinker_reply_part: torch.FloatTensor = None + + +@add_start_docstrings( + "The bare Qwen2.5OmniTalker Model outputting raw hidden-states without any specific head on top.", + QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniTalkerConfig"), +) +class Qwen2_5OmniTalkerModel(Qwen2_5_VLModel): + config_class = Qwen2_5OmniTalkerConfig + _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx) + + +class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): + config_class = Qwen2_5OmniTalkerConfig + base_model_prefix = "talker" + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + + self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size) + + self.model = Qwen2_5OmniTalkerModel(config) + self.codebook_size = config.vocab_size + self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False) + + self.codec_bos_token = config.tts_codec_start_token_id + self.codec_eos_token = config.tts_codec_end_token_id + self.codec_pad_token = config.tts_codec_pad_token_id + self.codec_mask_token = config.tts_codec_mask_token_id + + self.text_bos_token = config.tts_text_start_token_id + self.text_eos_token = config.tts_text_end_token_id + self.text_pad_token = config.tts_text_pad_token_id + + self.spatial_merge_size = self.config.spatial_merge_size + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @add_start_docstrings_to_model_forward(QWEN2_5OMNITHINKER_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Qwen2_5OmniTalkerCausalLMOutputWithPast, config_class="Qwen2_5OmniTalkerConfig" + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + thinker_reply_part: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + input_text_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + use_audio_in_video: Optional[bool] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from io import BytesIO
+        >>> from urllib.request import urlopen
+        >>> import librosa
+        >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration
+
+        >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+        >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+        >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+        >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+        >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Generate the caption in English: Glass is breaking."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None and position_ids is None:
+            if (
+                cache_position is None
+                or (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+            ):
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_text_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask,
+                    use_audio_in_video,
+                    audio_feature_lengths,
+                    video_second_per_grid,
+                )
+
+                inputs_embeds[:, -1, :] += self.get_input_embeddings()(
+                    torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device)
+                )
+                inputs_embeds[:, -2, :] += self.get_input_embeddings()(
+                    torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device)
+                )
+                self.rope_deltas = rope_deltas
+
+            else:
+                batch_size, seq_length = input_ids.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=input_ids.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+        if inputs_embeds is None:
+            # 1. 
Inference tokens after second token + codec_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :] + if thinker_reply_part.shape[1] > 1: + thinker_reply_part = thinker_reply_part[:, 1:, :] + + talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=talker_lm_input, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.codec_head(hidden_states) + logits = logits.float() + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniTalkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + thinker_reply_part=thinker_reply_part, + ) + + def _get_initial_cache_position(self, input_ids, model_kwargs): + # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily + inputs_embeds = model_kwargs.pop("inputs_embeds") + model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs["inputs_embeds"] = inputs_embeds + return model_kwargs + + # prepare inputs for talker lm generation + def prepare_inputs_for_generation( + self, + input_ids, + input_text_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + thinker_reply_part=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_audio_features=None, + audio_feature_attention_mask=None, + audio_feature_lengths=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values, + attention_mask, + inputs_embeds, + cache_position, + use_cache=use_cache, + thinker_reply_part=thinker_reply_part, + input_text_ids=input_text_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_audio_in_video=use_audio_in_video, + audio_feature_lengths=audio_feature_lengths, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + + if getattr(outputs, "thinker_reply_part", None) is not None: + model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part + + return model_kwargs + + +############################ +# Start Token2Wav # +############################ + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t 
= torch.arange(seq_len, device=x.device)
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
+            freqs = torch.stack((freqs, freqs), dim=-1)
+            freqs = freqs.reshape(*freqs.shape[:-2], -1)
+            freqs = freqs.repeat(batch_size, *([1] * freqs.dim()))
+            cos = freqs.cos()
+            sin = freqs.sin()
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Modified from Llama with a different rotate function, will be fixed in a future release
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+
+    def rotate_half_codec(x):
+        # x = rearrange(x, "... (d r) -> ... 
d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. 
+ + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143). 
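+    The encoder stacks an initial TDNN layer, several SE-Res2Net blocks, multi-layer feature aggregation and
+    attentive statistical pooling, and projects the result to a fixed-size utterance-level speaker embedding.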
+ """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + speaker_embedding: torch.Tensor, + condition_vector: torch.Tensor, + code_embed: torch.Tensor, + drop_audio_cond: Optional[bool] = False, + code_embed_uncond: Optional[bool] = None, + apply_cfg: Optional[bool] = True, + ): + if apply_cfg: + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) + condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) + elif drop_audio_cond: # cfg for cond audio + condition_vector = torch.zeros_like(condition_vector) + speaker_embedding = torch.zeros_like(speaker_embedding) + condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) + hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + + return hidden_states + + +# Transformer backbone using DiT blocks +class DiTCodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + 
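# Codec ids are embedded once and, in `forward`, repeated `repeats` times along the time
+        # axis, so a single speech token covers `repeats` mel frames.
+        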
self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, code, drop_code=False): + if drop_code: + code = torch.zeros_like(code) + code_embed = self.codec_embed(code) + + code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1) + return code_embed + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation +class Qwen2_5_OmniAdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation +class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states + + +# FeedForward +class DiTMLP(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + + self.ff = nn.ModuleList( + [ + nn.Linear(dim, inner_dim), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ] + ) + + def forward(self, hidden_states): + for layer in self.ff: + hidden_states = layer(hidden_states) + return hidden_states + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self._attn_implementation = config._attn_implementation + self.is_causal = False + + self.to_q = nn.Linear(config.hidden_size, self.inner_dim) + self.to_k = nn.Linear(config.hidden_size, self.inner_dim) + self.to_v = nn.Linear(config.hidden_size, self.inner_dim) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. 
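+        # Each projection maps (batch_size, seq_len, hidden_size) -> (batch_size, seq_len, inner_dim),
+        # where inner_dim = num_attention_heads * head_dim.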
+ query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, will be fixed at next release + cos, sin = position_embeddings + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = 
hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + + hidden_states = F.pad(hidden_states, (self.pad, self.pad), 
mode="replicate") + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels + ) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise ValueError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + 
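+# Illustrative usage sketch (not part of the model code): `TorchActivation1d` applies its pointwise
+# nonlinearity at twice the sample rate to reduce aliasing, i.e. upsample -> activation -> downsample,
+# and preserves the input shape. Assuming a `(batch, channels, time)` tensor:
+#
+#     act = TorchActivation1d(activation=SnakeBeta(in_features=8))
+#     x = torch.randn(1, 8, 256)
+#     y = act(x)  # same shape as x: (1, 8, 256)
+#
+# `AMPBlock` interleaves these activations with dilated residual convolutions, also keeping channel
+# count and length unchanged so that block outputs can be summed in the vocoder below.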
+@add_start_docstrings( + "The full Qwen2.5Omni Token2WavBigVGAN model. Which take mel spectrogram as input and predict waveform.", + QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniBigVGANConfig"), +) +class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniBigVGANConfig + + def __init__(self, config: Qwen2_5OmniBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu() + + +class RungeKutta4ODESolver: + def __init__(self, function, initial_value): + self.function = function + self.initial_value = initial_value + + self._one_third = 1 / 3 + self._two_thirds = 2 / 3 + + def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None): + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k2 = function(time_start + time_step * 
self._one_third, value_start + time_step * k1 * self._one_third) + k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third)) + k4 = function(time_end, value_start + time_step * (k1 - k2 + k3)) + return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 + + def _compute_step(self, function, time_start, time_step, time_end, value_start): + function_value_start = function(time_start, value_start) + return self._rk4_step( + function, time_start, time_step, time_end, value_start, function_value_start=function_value_start + ), function_value_start + + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + if time_point == time_start: + return value_start + if time_point == time_end: + return value_end + weight = (time_point - time_start) / (time_end - time_start) + return value_start + weight * (value_end - value_start) + + def integrate(self, time_points): + solution = torch.empty( + len(time_points), + *self.initial_value.shape, + dtype=self.initial_value.dtype, + device=self.initial_value.device, + ) + solution[0] = self.initial_value + + current_index = 1 + current_value = self.initial_value + for time_start, time_end in zip(time_points[:-1], time_points[1:]): + time_step = time_end - time_start + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + next_value = current_value + delta_value + + while current_index < len(time_points) and time_end >= time_points[current_index]: + solution[current_index] = self._linear_interpolation( + time_start, time_end, current_value, next_value, time_points[current_index] + ) + current_index += 1 + + current_value = next_value + + return solution + + +@add_start_docstrings( + "The full Qwen2.5Omni Token2WavDiT model. 
It takes speech tokens as input and predicts a mel spectrogram.",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniDiTConfig"),
+)
+class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
+    config_class = Qwen2_5OmniDiTConfig
+    _no_split_modules = ["DiTDecoderLayer"]
+
+    def __init__(self, config: Qwen2_5OmniDiTConfig):
+        super().__init__(config)
+        self.mel_dim = config.mel_dim
+        self.repeats = config.repeats
+        self.time_embed = DiTTimestepEmbedding(config.hidden_size)
+
+        self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
+        self.input_embed = DiTInputEmbedding(config)
+
+        self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim)
+
+        self.hidden_size = config.hidden_size
+        self.layers = config.num_hidden_layers
+        self.block_size = config.block_size
+        self.num_attention_heads = config.num_attention_heads
+
+        self.transformer_blocks = nn.ModuleList()
+        for i in range(config.num_hidden_layers):
+            self.transformer_blocks.append(
+                DiTDecoderLayer(
+                    config,
+                    look_ahead_block=1 if i in config.look_ahead_layers else 0,
+                    look_backward_block=1 if i in config.look_backward_layers else 0,
+                )
+            )
+
+        self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size)  # final modulation
+        self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)
+
+    def _create_block_diff(self, hidden_states):
+        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
+        block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size  # [seq_length]
+
+        block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
+        block_j = block_indices.unsqueeze(0)  # [1, seq_length]
+        block_diff = block_j - block_i  # (n, n)
+
+        return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)
+
+    def forward(
+        self,
+        hidden_states,
+        condition_vector,
+        speaker_embedding,
+        quantized_code,
+        time_step,
+        drop_audio_conditioning=False,
+        drop_code=False,
+        apply_cfg=True,
+    ):
+        batch_size = hidden_states.shape[0]
+        if time_step.ndim == 0:
+            time_step = time_step.repeat(batch_size)
+
+        # Compute embeddings
+        time_embedding = self.time_embed(time_step)
+        text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
+        text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None
+
+        hidden_states = self.input_embed(
+            hidden_states,
+            speaker_embedding,
+            condition_vector,
+            text_embedding,
+            drop_audio_cond=drop_audio_conditioning,
+            code_embed_uncond=text_embedding_unconditioned,
+            apply_cfg=apply_cfg,
+        )
+
+        # Compute positional encodings
+        position_embeddings = self.rotary_embed(hidden_states)
+        blockwise_difference = self._create_block_diff(hidden_states)
+
+        # Transformer blocks
+        for transformer_block in self.transformer_blocks:
+            hidden_states = transformer_block(
+                hidden_states,
+                time_embedding,
+                position_embeddings=position_embeddings,
+                block_diff=blockwise_difference,
+            )
+
+        hidden_states = self.norm_out(hidden_states, time_embedding)
+        output = self.proj_out(hidden_states)
+
+        return output
+
+    @torch.no_grad()
+    def sample(
+        self,
+        conditioning_vector,
+        reference_mel_spectrogram,
+        quantized_code,
+        num_steps=10,
+        guidance_scale=0.5,
+        sway_coefficient=-1.0,
+    ):
+        noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
+        maximum_duration = quantized_code.shape[1] * self.repeats
+        initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
+        batch_size = 
reference_mel_spectrogram.shape[0]
+        conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)
+
+        if batch_size != 1:
+            raise ValueError("Only batch size = 1 is currently supported")
+
+        def ode_function(time_step, hidden_states):
+            if guidance_scale < 1e-5:
+                prediction = self(
+                    hidden_states=hidden_states,
+                    speaker_embedding=conditioning_vector,
+                    condition_vector=reference_mel_spectrogram,
+                    quantized_code=quantized_code,
+                    time_step=time_step,
+                    drop_audio_conditioning=False,
+                    drop_code=False,
+                )
+                return prediction
+
+            model_output = self(
+                hidden_states=hidden_states,
+                quantized_code=quantized_code,
+                speaker_embedding=conditioning_vector,
+                condition_vector=reference_mel_spectrogram,
+                time_step=time_step,
+                apply_cfg=True,
+            )
+            guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)
+            return guided_prediction + (guided_prediction - null_prediction) * guidance_scale
+
+        initial_time = 0
+        time_embedding = torch.linspace(
+            initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
+        )
+
+        if sway_coefficient is not None:
+            time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)
+
+        ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state)
+        solution_trajectory = ode_solver.integrate(time_embedding)
+
+        generated_waveform = solution_trajectory[-1]
+        generated_mel_spectrogram = generated_waveform.permute(0, 2, 1)
+        return generated_mel_spectrogram
+
+
+@add_start_docstrings(
+    "The full Qwen2.5Omni Token2Wav model. It consists of a DiT model that takes speech tokens as input and predicts a mel spectrogram, and a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.",
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class="Qwen2_5OmniToken2WavConfig"),
+)
+class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel):
+    config_class = Qwen2_5OmniToken2WavConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"]
+
+    def __init__(self, config: Qwen2_5OmniToken2WavConfig):
+        super().__init__(config)
+        attn_impl = config._attn_implementation
+        if config._attn_implementation == "flash_attention_2":
+            logger.warning_once(
+                "Qwen2_5OmniToken2WavModel must run inference in fp32, but flash_attention_2 only supports fp16 and bf16; "
+                "the attention implementation of Qwen2_5OmniToken2WavModel will fall back to sdpa."
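+                # NOTE: the fp32 requirement comes from `Qwen2_5OmniForConditionalGeneration.enable_talker`,
+                # which casts the whole Token2Wav stack via `self.token2wav.float()`, while
+                # flash-attention kernels only provide fp16/bf16 paths.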
+            )
+            attn_impl = "sdpa"
+        elif config._attn_implementation == "eager":
+            logger.warning_once(
+                "Qwen2_5OmniToken2WavModel does not support the eager attention implementation, falling back to sdpa."
+            )
+            attn_impl = "sdpa"
+        self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config(
+            config.dit_config, attn_implementation=attn_impl
+        )
+        self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config(
+            config.bigvgan_config, attn_implementation=attn_impl
+        )
+
+    def forward(
+        self,
+        code,
+        conditioning,
+        reference_mel,
+        num_steps=10,
+        guidance_scale=0.5,
+        sway_coefficient=-1.0,
+        **kwargs,
+    ):
+        """Generates a waveform from input code and conditioning parameters."""
+
+        mel_spectrogram = self.code2wav_dit_model.sample(
+            conditioning,
+            reference_mel,
+            code,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            sway_coefficient=sway_coefficient,
+        )
+
+        waveform = self.code2wav_bigvgan_model(mel_spectrogram)
+
+        return waveform
+
+
+############################
+#     Start Qwen2.5Omni    #
+############################
+
+
+@add_start_docstrings(
+    """
+    The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models:
+    - [`Qwen2_5OmniThinkerForConditionalGeneration`]:
+        a causal auto-regressive transformer that takes text, audio, image and video as input and predicts text tokens.
+    - [`Qwen2_5OmniTalkerForConditionalGeneration`]:
+        a causal auto-regressive transformer that takes the thinker's hidden states and response as input and predicts speech tokens.
+    - [`Qwen2_5OmniToken2WavModel`]:
+        a DiT model that takes speech tokens as input and predicts a mel spectrogram, plus a BigVGAN vocoder that takes the mel spectrogram as input and predicts a waveform.
+    """,
+    QWEN2_5OMNI_START_DOCSTRING.format(config_class=Qwen2_5OmniConfig),
+)
+class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin):
+    config_class = Qwen2_5OmniConfig
+    _no_split_modules = [
+        "Qwen2_5OmniTalkerForConditionalGeneration",
+        "Qwen2_5OmniToken2WavModel",
+    ]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config)
+
+        self.has_talker = config.enable_audio_output
+        self.speaker_map = {}
+        if config.enable_audio_output:
+            self.enable_talker()
+        self.post_init()
+
+    def enable_talker(self):
+        self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config)
+        self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config)
+        self.token2wav.float()
+        self.has_talker = True
+
+    def load_speakers(self, path):
+        for key, value in torch.load(path).items():
+            self.speaker_map[key] = value
+        logger.info("Speakers {} loaded".format(list(self.speaker_map.keys())))
+
+    def disable_talker(self):
+        if hasattr(self, "talker"):
+            del self.talker
+        if hasattr(self, "token2wav"):
+            del self.token2wav
+        self.has_talker = False
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        *model_args,
+        config=None,
+        cache_dir=None,
+        ignore_mismatched_sizes=False,
+        force_download=False,
+        local_files_only=False,
+        token=None,
+        revision="main",
+        use_safetensors=None,
+        weights_only=True,
+        **kwargs,
+    ):
+        model = super().from_pretrained(
+            pretrained_model_name_or_path,
+            *model_args,
+            config=config,
+            cache_dir=cache_dir,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            use_safetensors=use_safetensors,
+            weights_only=weights_only,
+            **kwargs,
+        )
+        spk_path = cached_file(
+            
pretrained_model_name_or_path,
+            "spk_dict.pt",
+            subfolder=kwargs.pop("subfolder", None),
+            cache_dir=kwargs.pop("cache_dir", None),
+            force_download=kwargs.pop("force_download", False),
+            proxies=kwargs.pop("proxies", None),
+            resume_download=kwargs.pop("resume_download", None),
+            local_files_only=kwargs.pop("local_files_only", False),
+            token=kwargs.pop("use_auth_token", None),
+            revision=kwargs.pop("revision", None),
+        )
+        if spk_path is None:
+            raise ValueError(f"""{pretrained_model_name_or_path}/spk_dict.pt does not exist""")
+        model.load_speakers(spk_path)
+
+        return model
+
+    @torch.no_grad()
+    # TODO: raushan, defaults should be saved in generation config
+    def generate(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        speaker: str = "Chelsie",
+        use_audio_in_video: bool = False,
+        return_audio: Optional[bool] = None,
+        thinker_max_new_tokens: int = 1024,
+        talker_max_new_tokens: int = 4096,
+        talker_do_sample: bool = True,
+        talker_top_k: int = 40,
+        talker_top_p: float = 0.8,
+        talker_temperature: float = 0.9,
+        talker_eos_token_id: list[int] = [8292, 8294],
+        talker_repetition_penalty: float = 1.05,
+        **kwargs,
+    ):
+        r"""
+        Generate a text response and audio from the input.
+
+        Args:
+            input_ids (`Optional[torch.Tensor]`, *optional*):
+                Input ids, should be obtained from the processor.
+            speaker (`str`, defaults to "Chelsie"):
+                Which speaker should be used for the audio response.
+            use_audio_in_video (`bool`, defaults to False):
+                Whether or not to use the audio track in the video; should be the same as the parameter in `process_audio_info`.
+            return_audio (`Optional[bool]`, *optional*):
+                Whether or not to return the response in audio format. When `return_audio=None`, it defaults to `config.enable_audio_output`.
+            kwargs (*optional*):
+                - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
+                - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be input for the `generate` method of the
+                  thinker, talker and token2wav respectively. Prefixed keywords take priority over the ones without a prefix.
+        Returns:
+            When `return_audio=False`:
+                - **Text** (`torch.Tensor`): Generated text token sequence.
+            When `return_audio=True`:
+                - **Text** (`torch.Tensor`): Generated text token sequence.
+                - **Audio waveform** (`torch.Tensor`): Generated audio waveform.
+        """
+        if speaker not in self.speaker_map:
+            raise ValueError(f"{speaker} is not available, available speakers: {self.speaker_map.keys()}")
+        if return_audio and not self.has_talker:
+            raise ValueError(
+                "Cannot use talker when the talker module is not initialized. Use the `enable_talker` method or set `enable_audio_output` in the config to enable the talker."
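+                # `enable_talker` re-creates the talker and token2wav sub-models from the stored
+                # config; `disable_talker` is its counterpart and deletes them to free memory.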
+ ) + if return_audio is None: + return_audio = self.has_talker + if input_ids.shape[0] != 1 and return_audio: + raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output") + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + } + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": talker_eos_token_id, + "repetition_penalty": talker_repetition_penalty, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key == "input_features" or key == "attention_mask": + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + speaker_params = self.speaker_map[speaker] + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + + if not generate_audio: + return thinker_result + + # 2. 
Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) + if thinker_kwargs.get("input_features", None) is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device) + thinker_token_embeds = [ + token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(self.talker.device), + torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device), + torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device), + torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device) + ).to(self.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device) + ).to(self.talker.device) + + thinker_reply_part 
= torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + + talker_attention_mask = None + if "attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(self.talker.device) + + talker_result = self.talker.generate( + input_ids=talker_input_ids, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, + ) + talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + + # 3. Generate wavs from code + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + wav = self.token2wav( + talker_generate_codes.to(self.token2wav.device), + conditioning=speaker_params["cond"].to(self.token2wav.device).float(), + reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + +__all__ = [ + "Qwen2_5OmniConfig", + "Qwen2_5OmniThinkerConfig", + "Qwen2_5OmniTalkerConfig", + "Qwen2_5OmniToken2WavConfig", + "Qwen2_5OmniForConditionalGeneration", + "Qwen2_5OmniThinkerTextModel", + "Qwen2_5OmniThinkerForConditionalGeneration", + "Qwen2_5OmniTalkerModel", + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavDiTModel", + "Qwen2_5OmniToken2WavBigVGANModel", + "Qwen2_5OmniToken2WavModel", + "Qwen2_5OmniPreTrainedModel", + "Qwen2_5OmniPreTrainedModelForConditionalGeneration", +] diff --git a/docs/transformers/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py b/docs/transformers/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..57b2d43a3f5f6c26a4c49c343032a77fa2413f29 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py @@ -0,0 +1,359 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Qwen2.5Omni. 
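+
+A minimal usage sketch (illustrative only: the checkpoint id and the silent dummy audio
+below are assumptions, not something this file prescribes):
+
+    import numpy as np
+    from transformers import Qwen2_5OmniProcessor
+
+    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+    # One second of silence at the 16 kHz sampling rate the feature extractor expects.
+    audio = np.zeros(16000, dtype=np.float32)
+    text = processor.audio_bos_token + processor.audio_token + processor.audio_eos_token + " Transcribe this."
+    inputs = processor(text=[text], audio=[audio], return_tensors="pt")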
+""" + +import logging +import re +from typing import List, Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, VideoInput, make_batched_videos +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput + + +class Qwen2_5_OmniVideosKwargs(VideosKwargs): + fps: Optional[List[int]] = None + use_audio_in_video: Optional[bool] = None + seconds_per_chunk: Optional[float] = None + position_id_per_seconds: Optional[int] = None + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen2_5_OmniImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen2_5OmniProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_OmniVideosKwargs + images_kwargs: Qwen2_5_OmniImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "padding_side": "left", + }, + "videos_kwargs": { + "seconds_per_chunk": 2.0, + "position_id_per_seconds": 25, + "use_audio_in_video": False, + "min_pixels": 128 * 28 * 28, + "max_pixels": 768 * 28 * 28, + }, + "audio_kwargs": { + "sampling_rate": 16000, + "padding": "max_length", + "return_attention_mask": True, + }, + } + + +class Qwen2_5OmniProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5Omni processor. + [`Qwen2_5OmniProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5OmniProcessor.__call__`] and [`~Qwen2_5OmniProcessor.decode`] for more information. + + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor. + feature_extractor ([`WhisperFeatureExtractor`], *optional*): + The audio feature extractor. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The text tokenizer. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. If not provided, the default chat template is used. + """ + + attributes = ["image_processor", "feature_extractor", "tokenizer"] + image_processor_class = "Qwen2VLImageProcessor" + feature_extractor_class = "WhisperFeatureExtractor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + valid_kwargs = ["chat_template"] + + def __init__(self, image_processor=None, feature_extractor=None, tokenizer=None, chat_template=None): + super().__init__(image_processor, feature_extractor, tokenizer, chat_template=chat_template) + self.image_token = self.tokenizer.image_token + self.audio_token = self.tokenizer.audio_token + self.video_token = self.tokenizer.video_token + self.vision_bos_token = self.tokenizer.vision_bos_token + self.vision_eos_token = self.tokenizer.vision_eos_token + self.audio_bos_token = self.tokenizer.audio_bos_token + self.audio_eos_token = self.tokenizer.audio_eos_token + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + videos: VideoInput = None, + audio: AudioInput = None, + **kwargs: Unpack[Qwen2_5OmniProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and audio(s). 
This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
+        WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. To prepare the vision inputs,
+        this method forwards the `vision_infos` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`]
+        if `vision_infos` is not `None`. Please refer to the docstrings
+        of the above methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+            audio (`np.ndarray`, `List[np.ndarray]`):
+                The audio or batch of audio to be prepared. Each audio can be a NumPy array.
+        """
+
+        if text is None:
+            raise ValueError("You need to specify a `text` input to process.")
+
+        output_kwargs = self._merge_kwargs(
+            Qwen2_5OmniProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk")
+        position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds")
+        use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video")
+        fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
+
+        if audio is not None:
+            output_kwargs["audio_kwargs"]["padding"] = "max_length"  # Support "max_length" padding only here
+            audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+            audio_inputs["feature_attention_mask"] = audio_inputs.pop(
+                "attention_mask"
+            )  # rename attention_mask to feature_attention_mask to prevent conflicts later on
+            audio_inputs["input_features"] = audio_inputs.pop(
+                "input_features"
+            )  # re-insert input_features after feature_attention_mask (no renaming needed)
+            input_lengths = (audio_inputs["feature_attention_mask"].sum(-1) - 1) // 2 + 1
+            audio_lengths = iter((input_lengths - 2) // 2 + 1)
+        else:
+            audio_inputs = {}
+            audio_lengths = iter([])
+
+        if images is not None:
+            images_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
+            image_grid_thw = iter(images_inputs["image_grid_thw"])
+        else:
+            images_inputs = {}
+            image_grid_thw = iter([])
+
+        if videos is not None:
+            videos = make_batched_videos(videos)
+            videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
+            fps = [fps] * len(videos)
+            videos_inputs["video_second_per_grid"] = [
+                self.image_processor.temporal_patch_size / fps[i] for i in range(len(fps))
+            ]
+            video_grid_thw = iter(videos_inputs["video_grid_thw"])
+            video_second_per_grid = 
iter(videos_inputs["video_second_per_grid"]) + else: + videos_inputs = {} + video_grid_thw = iter([]) + video_second_per_grid = iter([]) + + if not isinstance(text, list): + text = [text] + + text = self.replace_multimodal_special_tokens( + text, + audio_lengths, + image_grid_thw, + video_grid_thw, + video_second_per_grid=video_second_per_grid, + use_audio_in_video=use_audio_in_video, + position_id_per_seconds=position_id_per_seconds, + seconds_per_chunk=seconds_per_chunk, + ) + + texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature( + data={**texts_inputs, **images_inputs, **videos_inputs, **audio_inputs}, + tensor_type=kwargs.get("return_tensors"), + ) + + def replace_multimodal_special_tokens( + self, + text, + audio_lengths, + image_grid_thw, + video_grid_thw, + video_second_per_grid, + use_audio_in_video, + position_id_per_seconds, + seconds_per_chunk, + ): + # Extend mm token length + merge_length = self.image_processor.merge_size**2 + + processed_text = [] + for sample in text: + positions = [] + special_tokens = [re.escape(tok) for tok in [self.audio_token, self.image_token, self.video_token]] + pattern = "|".join(special_tokens) + positions = sorted([(match.start(), match.group()) for match in re.finditer(pattern, sample)]) + positions.sort(key=lambda x: x[0]) + + for _, special_token in positions: + if special_token == self.audio_token: + sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1) + elif special_token == self.image_token: + image_seq_length = next(image_grid_thw).prod() // merge_length + sample = sample.replace(self.image_token, "<|image_placeholder|>" * image_seq_length, 1) + elif special_token == self.video_token: + if not use_audio_in_video: + video_seq_length = next(video_grid_thw).prod() // merge_length + sample = sample.replace(self.video_token, "<|video_placeholder|>" * video_seq_length, 1) + else: + audio_token_indices = np.arange(next(audio_lengths)) + curr_video_grid_thw = next(video_grid_thw) + height = curr_video_grid_thw[1] // self.image_processor.merge_size + width = curr_video_grid_thw[2] // self.image_processor.merge_size + video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1) + video_token_indices = np.broadcast_to( + video_token_indices, (video_token_indices.shape[0], height, width) + ).reshape(-1) + video_token_indices = ( + video_token_indices * next(video_second_per_grid) * position_id_per_seconds + ) + + tokens_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_token_indices, tokens_per_chunk) + audio_chunk_indexes = self.get_chunked_index(audio_token_indices, tokens_per_chunk) + + placeholder_string = self.vision_bos_token + self.audio_bos_token + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + if j < len(video_chunk_indexes): + video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0] + placeholder_string += "<|video_placeholder|>" * video_seq_length + if j < len(audio_chunk_indexes): + audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0] + placeholder_string += "<|audio_placeholder|>" * audio_seq_length + placeholder_string += self.audio_eos_token + self.vision_eos_token + sample = sample.replace( + self.vision_bos_token + self.video_token + self.vision_eos_token, + placeholder_string, + 1, + ) + + sample = sample.replace("<|audio_placeholder|>", self.audio_token) + sample = sample.replace("<|image_placeholder|>", 
self.image_token)
+            sample = sample.replace("<|video_placeholder|>", self.video_token)
+            processed_text.append(sample)
+        return processed_text
+
+    def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) -> list[tuple[int, int]]:
+        """
+        Splits token index list into chunks based on token value ranges.
+
+        Given a list of token indices, returns a list of (start, end) index tuples representing
+        slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
+
+        For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
+        - the first chunk contains token values < 1000,
+        - the second chunk contains values >= 1000 and < 2000, and so on.
+
+        Parameters:
+            token_indices (`np.ndarray`): A monotonically increasing list of token index values.
+            tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
+
+        Returns:
+            `List[Tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
+                and end (exclusive) indices of a chunk in `token_indices`.
+        """
+
+        def _iter():
+            i, start_idx = 0, 0  # skip bos token
+            current_chunk = 1
+            while i < len(token_indices):  # skip eos token
+                if token_indices[i] >= current_chunk * tokens_per_chunk:
+                    yield (start_idx, i)
+                    start_idx = i
+                    current_chunk += 1
+                i += 1
+            yield (start_idx, len(token_indices))
+
+        return list(_iter())
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(self, conversations, chat_template=None, **kwargs):
+        if isinstance(conversations[0], dict):
+            conversations = [conversations]
+        for conversation in conversations:
+            if (
+                conversation[0]["role"] != "system"
+                or conversation[0]["content"][0]["text"]
+                != "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
+            ):
+                logging.warning(
+                    "System prompt modified, audio output may not work as expected. 
" + + "Audio output mode only works when using default system prompt 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'" + ) + return super().apply_chat_template(conversations, chat_template, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list( + dict.fromkeys( + tokenizer_input_names + + feature_extractor_input_names + + image_processor_input_names + + ["feature_attention_mask"] + + ["video_second_per_grid"] + ) + ) + + +__all__ = ["Qwen2_5OmniProcessor"] diff --git a/docs/transformers/src/transformers/models/qwen2_5_vl/__init__.py b/docs/transformers/src/transformers/models/qwen2_5_vl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9f44a7a05fc65136552948be9acef53e5ea106 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_vl/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen2_5_vl import * + from .modeling_qwen2_5_vl import * + from .processing_qwen2_5_vl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/docs/transformers/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..a700a9928861f32bc90181ea3e25bc4e19bd10a6 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -0,0 +1,322 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + self.initializer_range = initializer_range + + +class Qwen2_5_VLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + image_token_id (`int`, *optional*): + Token index used as placeholder for image embeddings. + video_token_id (`int`, *optional*): + Token index used as placeholder for video embeddings. + + ```python + >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + rope_scaling=None, + image_token_id=None, + video_token_id=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations + # one can set it to "linear"/"dynamic" etc. 
to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + # For BC use all kwargs to init `TextConfig` + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + super().__init__(**kwargs) + + +__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"] diff --git a/docs/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/docs/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..4da0f59bf469af237bfca2b0c1a4d13a7f7a9683 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -0,0 +1,2061 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from 
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + 
in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from 
computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." 
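+                # A minimal sketch of the migration this warning describes (it mirrors what
+                # Qwen2_5_VisionTransformerPretrainedModel.forward does further below; names
+                # `block` and `hidden_states` are placeholders for illustration):
+                #   emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+                #   position_embeddings = (emb.cos(), emb.sin())
+                #   block(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)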
+ ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." 
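+                # `cu_seqlens` holds cumulative patch counts per image/video. A worked example with
+                # assumed values: two images of 4 and 6 patches give cu_seqlens = [0, 4, 10], and the
+                # mask construction below only lets patches attend within their own [0:4) / [4:10)
+                # block, i.e. attention is block-diagonal across the packed sequence.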
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_output = F.scaled_dot_product_attention(
+            q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
+        )
+        attn_output = attn_output.squeeze(0).transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
+    "eager": Qwen2_5_VLVisionAttention,
+    "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
+    "sdpa": Qwen2_5_VLVisionSdpaAttention,
+}
+
+
+class Qwen2_5_VLVisionBlock(nn.Module):
+    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+        super().__init__()
+        self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
+            config.hidden_size, num_heads=config.num_heads
+        )
+        self.mlp = Qwen2_5_VLMLP(config, bias=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+Qwen2_5_VL_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Qwen2_5_VLConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.",
+    Qwen2_5_VL_START_DOCSTRING,
+)
+class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
+    config_class = Qwen2_5_VLConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_static_cache = False  # TODO (joao): fix.
torch.compile failing probably due to `cache_positions` + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + 
vit_merger_window_size,
+                num_windows_w,
+                vit_merger_window_size,
+            )
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t,
+                num_windows_h * num_windows_w,
+                vit_merger_window_size,
+                vit_merger_window_size,
+            )
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+
+        return window_index, cu_window_seqlens
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                The packed, flattened image/video patches to embed (they are projected to hidden
+                states by `patch_embed` below).
+            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                The temporal, height and width of feature shape of each image in LLM.
+
+        Returns:
+            `torch.Tensor`: The merged and re-ordered visual token embeddings.
+        """
+        hidden_states = self.patch_embed(hidden_states)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+        cu_window_seqlens = torch.tensor(
+            cu_window_seqlens,
+            device=hidden_states.device,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+        seq_len, _ = hidden_states.size()
+        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            # Select dtype based on the following factors:
+            #  - FA2 requires that cu_seqlens_q must have dtype int32
+            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852 for more information
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
+                )
+            else:
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
+
+        hidden_states = self.merger(hidden_states)
+        reverse_indices = torch.argsort(window_index)
+        hidden_states = hidden_states[reverse_indices, :]
+
+        return hidden_states
+
+
+class Qwen2_5_VLRotaryEmbedding(nn.Module):
+    def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", 
config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. 
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    mrope_section = mrope_section * 2
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Qwen2_5_VLAttention(nn.Module):
+    """
+    Multi-headed attention from the 'Attention Is All You Need' paper. Modified to use sliding window attention, as in
+    Longformer and "Generating Long Sequences with Sparse Transformers".
+    """
+
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+        self.rope_scaling = config.rope_scaling
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
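+                # A worked example of `apply_multimodal_rotary_pos_emb` above, with values typical of
+                # the released checkpoints (assumed here for illustration): with head_dim = 128 and
+                # mrope_section = [16, 24, 24] (summing to head_dim // 2), the section list is doubled
+                # to [16, 24, 24, 16, 24, 24] so it spans the full 128-dim cos/sin. The i-th chunk then
+                # keeps the (i % 3)-th plane of the 3D cos/sin, interleaving temporal, height and width
+                # frequencies across the head dimension before the usual rotate_half rotation.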
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. 
The only required change would be on the forward pass,
+    where it needs to correctly call the public API of flash attention and deal with padding tokens
+    in case the input contains any of them. Additionally, note that sliding window attention is only applied to
+    layers with index >= config.max_window_layers (see the gating below); the earlier layers use full attention.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+    ):
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        # Because the input can be padded, the absolute sequence length depends on the max position id.
+        cos, sin = position_embeddings
+        query_states, key_states = apply_multimodal_rotary_pos_emb(
+            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+        )
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability;
+        # the input hidden states therefore get silently cast to float32. Hence, we need to
+        # cast them back to float16 just to be sure everything works as expected.
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be related to"
+                f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast back the input to"
+                f" {target_dtype}."
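+                # Illustrative scenario (assumed, not exhaustive): LoRA fine-tuning with PEFT keeps
+                # LayerNorm/RMSNorm parameters in float32 for stability, so `hidden_states` (and thus
+                # q/k/v) arrive here in float32 even though the weights are float16/bfloat16; the casts
+                # below restore the dtype the Flash Attention kernels expect.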
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        if (
+            self.config.use_sliding_window
+            and getattr(self.config, "sliding_window", None) is not None
+            and self.layer_idx >= self.config.max_window_layers
+        ):
+            sliding_window = self.config.sliding_window
+        else:
+            sliding_window = None
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            sliding_window=sliding_window,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        # Flash Attention never materializes the attention weights, so there is nothing to return here.
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention):
+    """
+    Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt
+    to the SDPA API.
+    """
+
+    # Adapted from Qwen2Attention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
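+                # A minimal sketch of opting out of SDPA entirely, mirroring the suggestion in the
+                # warning above (checkpoint name assumed for illustration):
+                #   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                #       "Qwen/Qwen2.5-VL-7B-Instruct", attn_implementation="eager"
+                #   )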
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
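+                # A hedged usage sketch for actually enabling sliding window attention, which
+                # requires the Flash Attention 2 backend (checkpoint name assumed for illustration):
+                #   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                #       "Qwen/Qwen2.5-VL-7B-Instruct",
+                #       attn_implementation="flash_attention_2",
+                #   )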
+ ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLTextConfig + + def __init__(self, config: Qwen2_5_VLTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
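+                    # Typical training setup this warning refers to (a sketch, assuming the standard
+                    # PreTrainedModel APIs):
+                    #   model.gradient_checkpointing_enable()
+                    #   outputs = model(input_ids, use_cache=False)  # caching is disabled automatically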
+ ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. 
+            target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static
+                cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+            config (`Qwen2_5_VLConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class that is currently being used to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.get_text_config().sliding_window is not None:
+                # if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
+                # the check is needed to verify whether the current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+@dataclass
+class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +QWEN2_5_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)`):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+            [`Qwen2_5_VLImageProcessor`] for processing images.
+        pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
+            The tensors corresponding to the input videos. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+            [`Qwen2_5_VLImageProcessor`] for processing videos.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+            The temporal, height and width of feature shape of each video in LLM.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+            The rope index difference between sequence length and multimodal rope.
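+
+    Example (a minimal usage sketch; the checkpoint name and image URL are illustrative placeholders):
+
+    ```python
+    >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+    >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    ...     "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+    ... )
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "image", "image": "https://example.com/demo.jpeg"},
+    ...             {"type": "text", "text": "Describe this image."},
+    ...         ],
+    ...     }
+    ... ]
+    >>> inputs = processor.apply_chat_template(
+    ...     messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+    ... ).to(model.device)
+    >>> output_ids = model.generate(**inputs, max_new_tokens=30)
+    >>> processor.batch_decode(output_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
+    ```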
+""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + + text_config = config.get_text_config() + self.model = Qwen2_5_VLModel._from_config(text_config) + self.vocab_size = text_config.vocab_size + self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. 
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = 
second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + ## normalize type, send to device. + second_per_grid_t = torch.as_tensor( + second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device + ) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + 
second_per_grid_ts: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "image"},
+        ...             {"type": "text", "text": "What is shown in this image?"},
+        ...         ],
+        ...     },
+        ... ]
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.visual.dtype)
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                n_image_features = image_embeds.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+
+                mask = input_ids == self.config.image_token_id
+                mask_unsqueezed = mask.unsqueeze(-1)
+                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+                image_mask = mask_expanded.to(inputs_embeds.device)
+
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+                n_video_features = video_embeds.shape[0]
+                if n_video_tokens != n_video_features:
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
+
+                mask = input_ids == self.config.video_token_id
+                mask_unsqueezed = mask.unsqueeze(-1)
+                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+                video_mask = mask_expanded.to(inputs_embeds.device)
+
+                video_embeds =
video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + 
image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for 
sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/docs/transformers/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..e34724c7906a4c76f7f216d65ef0c99314d841bb --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -0,0 +1,988 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen2.5-VL model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + PatchEmbed, + PatchMerger, + Qwen2RMSNorm, + Qwen2VLCausalLMOutputWithPast, + Qwen2VLForConditionalGeneration, + Qwen2VLModel, + Qwen2VLPreTrainedModel, + VisionAttention, + VisionRotaryEmbedding, + VisionSdpaAttention, +) +from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, VideoInput +from ...modeling_flash_attention_utils import is_flash_attn_available +from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func + + +logger = logging.get_logger(__name__) + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + self.initializer_range = initializer_range + + +class Qwen2_5_VLTextConfig(Qwen2VLTextConfig): + model_type = "qwen2_5_vl_text" + + +class Qwen2_5_VLConfig(Qwen2VLConfig): + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig} + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + 
+ def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(PatchEmbed): + pass + + +class Qwen2_5_VisionRotaryEmbedding(VisionRotaryEmbedding): + pass + + +class Qwen2_5_VLPatchMerger(PatchMerger): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__(dim, context_dim, spatial_merge_size) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionAttention(VisionAttention): + pass + + +class Qwen2_5_VLVisionSdpaAttention(VisionSdpaAttention): + pass + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if 
module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = 
(index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings + ) + else: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLModel(Qwen2VLModel): + pass + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): + pass + + +class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + 
image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate the 3D RoPE index based on the temporal, height and width dimensions of images and videos in the LLM.
+
+        Explanation:
+            Each embedding sequence either contains both vision and text embeddings, or text embeddings only.
+
+            For a pure text embedding sequence, the rotary position embedding is no different from that of modern LLMs.
+            Examples:
+                input_ids: [T T T T T], here T is for text.
+                temporal position_ids: [0, 1, 2, 3, 4]
+                height position_ids: [0, 1, 2, 3, 4]
+                width position_ids: [0, 1, 2, 3, 4]
+
+            For a mixed vision and text embedding sequence, we calculate a 3D rotary position embedding for the vision
+            part and a 1D rotary position embedding for the text part.
+            Examples:
+                Temporal (Time): 3 patches, representing different segments of the video in time.
+                Height: 2 patches, dividing each frame vertically.
+                Width: 2 patches, dividing each frame horizontally.
+                We also have some important parameters:
+                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that consecutive temporal patches differ by 50 in their temporal position IDs.
+                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+                text temporal position_ids: [101, 102, 103, 104, 105]
+                text height position_ids: [101, 102, 103, 104, 105]
+                text width position_ids: [101, 102, 103, 104, 105]
+            Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image in LLM.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM.
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + ## normalize type, send to device. 
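+                    # Note: at this point `second_per_grid_t` may be a python number (0 for
+                    # images, the 1.0 fallback for videos) or an element of `second_per_grid_ts`;
+                    # `torch.as_tensor` below makes the time-scaling multiplication dtype- and
+                    # device-stable in all of these cases.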
+ second_per_grid_t = torch.as_tensor( + second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device + ) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "image"},
+        ...             {"type": "text", "text": "What is shown in this image?"},
+        ...         ],
+        ...     },
+        ... ]
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.visual.dtype)
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                n_image_features = image_embeds.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
+
+                mask = input_ids == self.config.image_token_id
+                mask_unsqueezed = mask.unsqueeze(-1)
+                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+                image_mask = mask_expanded.to(inputs_embeds.device)
+
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+                n_video_features = video_embeds.shape[0]
+                if n_video_tokens != n_video_features:
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
+
+                mask = input_ids == self.config.video_token_id
+                mask_unsqueezed = mask.unsqueeze(-1)
+                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+                video_mask = mask_expanded.to(inputs_embeds.device)
+
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+        # if we get 4D attention mask we cannot calculate rope deltas anymore.
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + 
model_inputs["pixel_values_videos"] = None + + return model_inputs + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLImagesKwargs(Qwen2VLImagesKwargs): + pass + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen2_5_VLImagesKwargs + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(Qwen2VLProcessor): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + image_processor_class = "AutoImageProcessor" + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. 
+ + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." 
+ ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + num_video_tokens = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + +__all__ = [ + "Qwen2_5_VLConfig", + "Qwen2_5_VLTextConfig", + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + "Qwen2_5_VLProcessor", +] diff --git a/docs/transformers/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/docs/transformers/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..d0994323833c869eb34da85ffd4604a3768864a3 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -0,0 +1,246 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
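For orientation, here is a minimal usage sketch of the processor defined above (and regenerated below). The checkpoint name `Qwen/Qwen2.5-VL-7B-Instruct`, the `<|video_pad|>` placeholder string, and `temporal_patch_size=2` are assumptions for illustration; the `fps` keyword and the resulting `second_per_grid_ts` follow the `__call__` logic shown above (`temporal_patch_size / fps` per video).

```python
# Minimal sketch: calling the processor on a short clip (assumed checkpoint
# and placeholder token; adjust to your setup).
import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# 8 RGB frames of 224x224, standing in for a decoded video clip.
video = np.zeros((8, 224, 224, 3), dtype=np.uint8)

inputs = processor(
    text=["Describe this clip: <|video_pad|>"],
    videos=[video],
    fps=4.0,  # a single number, or one value per video
    return_tensors="pt",
)
# Assuming temporal_patch_size=2, each temporal grid step covers
# 2 / 4.0 = 0.5 seconds of video.
print(inputs["second_per_grid_ts"])  # e.g. [0.5] (possibly as a tensor)
print(inputs.keys())  # input_ids, attention_mask, pixel_values_videos, video_grid_thw, ...
```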
+from typing import List, Optional, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, VideoInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
+    fps: Union[List[float], float]
+
+
+class Qwen2_5_VLImagesKwargs(ImagesKwargs):
+    min_pixels: Optional[int]
+    max_pixels: Optional[int]
+    patch_size: Optional[int]
+    temporal_patch_size: Optional[int]
+    merge_size: Optional[int]
+
+
+class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Qwen2_5_VLImagesKwargs
+    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "videos_kwargs": {"fps": 2.0},
+    }
+
+
+class Qwen2_5_VLProcessor(ProcessorMixin):
+    r"""
+    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
+    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
+    Args:
+        image_processor ([`Qwen2VLImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`Qwen2TokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
+
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+        self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if getattr(tokenizer, "image_token_id", None)
+            else tokenizer.convert_tokens_to_ids(self.image_token)
+        )
+        self.video_token_id = (
+            tokenizer.video_token_id
+            if getattr(tokenizer, "video_token_id", None)
+            else tokenizer.convert_tokens_to_ids(self.video_token)
+        )
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        videos: VideoInput = None,
+        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `images`/`videos` and `kwargs` arguments to
+        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `images`/`videos` is not `None`.
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
+            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
+        """
+        output_kwargs = self._merge_kwargs(
+            Qwen2_5_VLProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        if images is not None:
+            image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
+            image_grid_thw = image_inputs["image_grid_thw"]
+        else:
+            image_inputs = {}
+            image_grid_thw = None
+
+        if videos is not None:
+            videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
+            video_grid_thw = videos_inputs["video_grid_thw"]
+
+            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
+            if isinstance(fps, (int, float)):
+                second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
+            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
+                second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
+            else:
+                raise ValueError(
+                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
+                )
+            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
+
+        else:
+            videos_inputs = {}
+            video_grid_thw = None
+
+        if not isinstance(text, list):
+            text = [text]
+
+        text = text.copy()  # below lines change text in-place
+        if image_grid_thw is not None:
+            merge_length = self.image_processor.merge_size**2
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    num_image_tokens = image_grid_thw[index].prod() // merge_length
+                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+        if video_grid_thw is not None:
+            merge_length = self.image_processor.merge_size**2
+            index = 0
+            for i in range(len(text)):
+                while self.video_token in text[i]:
+                    num_video_tokens = video_grid_thw[index].prod() // merge_length
+                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
+
+        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def post_process_image_text_to_text(
+        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
+    ):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+            skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
+            **kwargs:
+                Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+        Returns:
+            `List[str]`: The decoded text.
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] diff --git a/docs/transformers/src/transformers/models/qwen2_audio/__init__.py b/docs/transformers/src/transformers/models/qwen2_audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b126d712bd469e57188316a0b0ba127f31ddfd2e --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_audio/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen2_audio import * + from .modeling_qwen2_audio import * + from .processing_qwen2_audio import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/docs/transformers/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..23a1c699dd4c5e548bcf139095997d6fdf4fadca --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2Audio model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class Qwen2AudioEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2AudioEncoder`]. It is used to instantiate a + Qwen2-Audio audio encoder according to the specified arguments, defining the model architecture. 
Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
+    architecture.
+
+    e.g. [Qwen/Qwen2-Audio-7B](https://huggingface.co/Qwen/Qwen2-Audio-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_mel_bins (`int`, *optional*, defaults to 128):
+            Number of mel features used per input feature. Should correspond to the value used in the
+            `Qwen2AudioProcessor` class.
+        encoder_layers (`int`, *optional*, defaults to 32):
+            Number of encoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 20):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 5120):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        d_model (`int`, *optional*, defaults to 1280):
+            Dimensionality of the layers.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        max_source_positions (`int`, *optional*, defaults to 1500):
+            The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
+
+    Example:
+
+    ```python
+    >>> from transformers import Qwen2AudioEncoderConfig, Qwen2AudioEncoder
+
+    >>> # Initializing a Qwen2AudioEncoderConfig
+    >>> configuration = Qwen2AudioEncoderConfig()
+
+    >>> # Initializing a Qwen2AudioEncoder (with random weights)
+    >>> model = Qwen2AudioEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_audio_encoder"
+
+    def __init__(
+        self,
+        num_mel_bins=128,
+        encoder_layers=32,
+        encoder_attention_heads=20,
+        encoder_ffn_dim=5120,
+        encoder_layerdrop=0.0,
+        d_model=1280,
+        dropout=0.0,
+        attention_dropout=0.0,
+        activation_function="gelu",
+        activation_dropout=0.0,
+        scale_embedding=False,
+        initializer_range=0.02,
+        max_source_positions=1500,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_mel_bins = num_mel_bins
+        self.d_model = d_model
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_function = activation_function
+        self.activation_dropout = activation_dropout
+        self.encoder_layerdrop = encoder_layerdrop
+        self.num_hidden_layers = encoder_layers
+        self.initializer_range = initializer_range
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.max_source_positions = max_source_positions
+
+
+class Qwen2AudioConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2AudioForConditionalGeneration`]. It is used to instantiate a
+    Qwen2-Audio model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of Qwen2-Audio.
+
+    e.g. [Qwen/Qwen2-Audio-7B](https://huggingface.co/Qwen/Qwen2-Audio-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        audio_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2AudioEncoderConfig`):
+            The config object or dictionary of the audio backbone.
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+            The config object or dictionary of the text backbone.
+        audio_token_index (`int`, *optional*, defaults to 151646):
+            The audio token index used to encode the audio prompt.
+ + Example: + + ```python + >>> from transformers import Qwen2AudioForConditionalGeneration, Qwen2AudioConfig, Qwen2AudioEncoderConfig, Qwen2Config + + >>> # Initializing a Qwen2AudioEncoder config + >>> audio_config = Qwen2AudioEncoderConfig() + + >>> # Initializing a Qwen2 config + >>> text_config = Qwen2Config() + + >>> # Initializing a Qwen2Audio configuration + >>> configuration = Qwen2AudioConfig(audio_config, text_config) + + >>> # Initializing a model from the qwen2-audio style configuration + >>> model = Qwen2AudioForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_audio" + attribute_map = { + "audio_token_id": "audio_token_index", + } + sub_configs = {"text_config": AutoConfig, "audio_config": AutoConfig} + + def __init__( + self, + audio_config=None, + text_config=None, + audio_token_index=151646, + **kwargs, + ): + self.audio_token_index = audio_token_index + + if isinstance(audio_config, dict): + audio_config["model_type"] = ( + audio_config["model_type"] if "model_type" in audio_config else "qwen2_audio_encoder" + ) + audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config) + elif audio_config is None: + audio_config = CONFIG_MAPPING["qwen2_audio_encoder"]( + d_model=1280, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + encoder_layerdrop=0.0, + encoder_layers=32, + num_mel_bins=128, + max_source_positions=1500, + scale_embedding=False, + activation_function="gelu", + ) + + self.audio_config = audio_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]() + + self.text_config = text_config + + super().__init__(**kwargs) + + +__all__ = ["Qwen2AudioConfig", "Qwen2AudioEncoderConfig"] diff --git a/docs/transformers/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/docs/transformers/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..4496ef73b8cf8b7ec5ca2335d5854f043c2292d9 --- /dev/null +++ b/docs/transformers/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -0,0 +1,1188 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
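Before the modeling code, a short sketch of how the `__init__` above resolves plain-dict sub-configs through `CONFIG_MAPPING`. The tiny sizes here are illustrative only, and the `audio_token_id` read relies on the `attribute_map` alias declared on the class.

```python
# Sketch: building a Qwen2AudioConfig from dicts (illustrative sizes).
from transformers import Qwen2AudioConfig

config = Qwen2AudioConfig(
    audio_config={"model_type": "qwen2_audio_encoder", "encoder_layers": 2, "d_model": 64},
    text_config={"model_type": "qwen2", "hidden_size": 64, "num_hidden_layers": 2},
)
print(type(config.audio_config).__name__)  # Qwen2AudioEncoderConfig
print(config.audio_token_id)  # 151646, via the attribute_map alias for audio_token_index
```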
+"""PyTorch Qwen2Audio model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2AudioConfig" + + +@dataclass +class Qwen2AudioCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2Audio causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are + two sets of pre-computed hidden-states: key and values states in the self-attention blocks. + The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`. + It is a [`~cache_utils.Cache`] instance. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `input_ids` of shape `(batch_size, sequence_length)`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + attention_mask (`torch.FloatTensor`, *optional*): + Attentions mask, used to update attention mask and position_ids. 
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    attention_mask: Optional[torch.FloatTensor] = None
+
+
+class Qwen2AudioAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    # Copied from transformers.models.whisper.modeling_whisper.WhisperAttention.__init__ with Whisper->Qwen2Audio
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        layer_idx: Optional[int] = None,
+        config: Optional[Qwen2AudioConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        if layer_idx is None and is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+        self.layer_idx = layer_idx
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    @deprecate_kwarg("key_value_states", version="4.52")
+    @deprecate_kwarg("past_key_value", version="4.52")
+    @deprecate_kwarg("cache_position", version="4.52")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_probs, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights, None
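To make the `_shape` reshaping above concrete, here is an illustrative shape walk-through of the eager attention path, using the encoder defaults (d_model=1280, 20 heads, so head_dim=64). The batch and sequence sizes are arbitrary examples.

```python
# Shape walk-through of the eager attention path (editor's illustration).
import torch

bsz, tgt_len, embed_dim, num_heads = 2, 750, 1280, 20
head_dim = embed_dim // num_heads  # 64

hidden_states = torch.randn(bsz, tgt_len, embed_dim)

# _shape: (bsz, seq, embed) -> (bsz, num_heads, seq, head_dim)
q = hidden_states.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2)
k = hidden_states.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2)

# Scaled dot product, as in the forward above (scaling applied to q).
attn_weights = torch.matmul(q * head_dim**-0.5, k.transpose(2, 3))
print(attn_weights.shape)  # torch.Size([2, 20, 750, 750])
```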
+
+
+class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
+    """
+    Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stay
+    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2.__init__ with Whisper->Qwen2Audio
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    @deprecate_kwarg("key_value_states", version="4.52")
+    @deprecate_kwarg("past_key_value", version="4.52")
+    @deprecate_kwarg("cache_position", version="4.52")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # Qwen2AudioFlashAttention2 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
+        # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+ key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, : key_states.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout if self.training else 0.0, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, None + + +class Qwen2AudioSdpaAttention(Qwen2AudioAttention): + @deprecate_kwarg("key_value_states", version="4.52") + @deprecate_kwarg("past_key_value", version="4.52") + @deprecate_kwarg("cache_position", version="4.52") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, None + + +QWEN2AUDIO_ATTENTION_CLASSES = { + "eager": Qwen2AudioAttention, + "flash_attention_2": Qwen2AudioFlashAttention2, + "sdpa": Qwen2AudioSdpaAttention, +} + + +# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO +class Qwen2AudioEncoderLayer(nn.Module): + def __init__(self, config: Qwen2AudioConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = QWEN2AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +QWEN2AUDIO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2AudioConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2Audio Model outputting raw hidden-states without any specific head on top.", + QWEN2AUDIO_START_DOCSTRING, +) +class Qwen2AudioPreTrainedModel(PreTrainedModel): + config_class = Qwen2AudioConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2AudioAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + # important: this ported version of Qwen2Audio isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.audio_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2AUDIOENCODER_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2AudioEncoderConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The audio model from Qwen2Audio without any head or projection on top.""", + QWEN2AUDIOENCODER_START_DOCSTRING, +) +# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder with Whisper->Qwen2Audio +class Qwen2AudioEncoder(Qwen2AudioPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Qwen2AudioEncoderLayer`]. 
+
+    Args:
+        config: Qwen2AudioEncoderConfig
+    """
+
+    # Ignore copy
+    config_class = Qwen2AudioEncoderConfig
+    main_input_name = "input_features"
+    _no_split_modules = ["Qwen2AudioEncoderLayer"]
+
+    def __init__(self, config: Qwen2AudioEncoderConfig):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.encoder_layerdrop
+
+        embed_dim = config.d_model
+        self.num_mel_bins = config.num_mel_bins
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_source_positions
+        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+
+        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+        self.embed_positions.requires_grad_(False)
+
+        self.layers = nn.ModuleList([Qwen2AudioEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.layer_norm = nn.LayerNorm(config.d_model)
+        # Ignore copy
+        self.avg_pooler = nn.AvgPool1d(2, stride=2)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.conv1
+
+    def set_input_embeddings(self, value: nn.Module):
+        self.conv1 = value
+
+    def forward(
+        self,
+        input_features,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        Args:
+            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
+                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
+                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+            attention_mask (`torch.Tensor`, *optional*):
+                Qwen2Audio does not support masking of the `input_features`; this argument is preserved for
+                compatibility, but it is not used. By default, the silence in the input log mel spectrogram is ignored.
+            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+
+        expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
+        if input_features.shape[-1] != expected_seq_length:
+            raise ValueError(
+                f"Qwen2Audio expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}.
Make sure to pad the input mel features to {expected_seq_length}." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Ignore copy + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.avg_pooler(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class Qwen2AudioMultiModalProjector(nn.Module): + def __init__(self, config: Qwen2AudioConfig): + super().__init__() + self.linear = nn.Linear(config.audio_config.d_model, config.text_config.hidden_size, bias=True) + + def forward(self, audio_features): + hidden_states = self.linear(audio_features) + return hidden_states + + +QWEN2AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape 
`(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`):
+            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
+            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
+            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
+            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
+            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+            Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
+            two sets of pre-computed hidden-states: key and value states in the self-attention blocks.
+            The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True`.
+            It is a [`~cache_utils.Cache`] instance.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
+            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+            all `input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
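To connect `input_features` to the number of audio positions the language model sees, here is a worked example (editor's illustration) of the length arithmetic from `_get_feat_extract_output_lengths` in the encoder above: `conv2` (stride 2) roughly halves the mel length, and the stride-2 average pool halves it again.

```python
# Worked example of the feature-length arithmetic defined in the encoder above.
def feat_extract_output_lengths(mel_len: int) -> tuple:
    input_lengths = (mel_len - 1) // 2 + 1        # after conv2 (stride 2)
    output_lengths = (input_lengths - 2) // 2 + 1  # after AvgPool1d(2, stride=2)
    return input_lengths, output_lengths

# The encoder expects exactly max_source_positions * 2 = 3000 mel frames:
print(feat_extract_output_lengths(3000))  # (1500, 750) -> 750 audio embeddings
```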
+
+
+@add_start_docstrings(
+    """The QWEN2AUDIO model which consists of an audio backbone and a language model.""",
+    QWEN2AUDIO_START_DOCSTRING,
+)
+class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
+    def __init__(self, config: Qwen2AudioConfig):
+        super().__init__(config)
+        self.audio_tower = AutoModel.from_config(config.audio_config)  # Usually a `Qwen2AudioEncoder` instance
+
+        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
+        self.vocab_size = config.text_config.vocab_size
+        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
+
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
+        self.post_init()
+
+    @property
+    def padding_side(self):
+        return self._padding_side
+
+    @padding_side.setter
+    def padding_side(self, padding_side: str):
+        if padding_side not in ["left", "right"]:
+            raise ValueError(f"{padding_side} is not `left` or `right`.")
+        self._padding_side = padding_side
+
+    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
+    def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()
+
+    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
+    def get_decoder(self):
+        return self.language_model.get_decoder()
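The merge routine that follows scatters variable-length audio embeddings into the text sequence. Its first step builds a boolean mask that keeps only the valid positions of each audio; here is a small, self-contained illustration of that arange/expand trick (the sizes are arbitrary examples, not model values):

```python
# Sketch of the variable-length audio mask built at the start of
# _merge_input_ids_with_audio_features below.
import torch

num_audio_tokens = torch.tensor([5, 3])          # two audios: 5 and 3 embeddings
max_audio_tokens = int(num_audio_tokens.max())   # 5

mask = torch.arange(max_audio_tokens).expand(2, max_audio_tokens) < num_audio_tokens.unsqueeze(1)
print(mask.int())
# tensor([[1, 1, 1, 1, 1],
#         [1, 1, 1, 0, 0]])

audio_features = torch.randn(2, max_audio_tokens, 4)  # (num_audios, max_audio_tokens, embed_dim)
masked = audio_features[mask].view(-1, 4)             # (5 + 3, embed_dim)
print(masked.shape)                                   # torch.Size([8, 4])
```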
of all audios in the batch
+            num_audio_tokens (`torch.LongTensor` of shape `(num_audios)`):
+                The length of the audio embeddings of each audio as stacked in `audio_features`
+            inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
+                Token embeddings before merging with the audio embeddings
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Input ids of the tokens, possibly containing audio token placeholders
+            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Mask to avoid performing attention on padding token indices.
+            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels, recalculated here to support training (if provided)
+        Returns:
+            final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
+
+        Explanation:
+            Each audio has variable-length embeddings, with the length specified by `num_audio_tokens`.
+            `audio_features` is the concatenation of all audio embedding vectors.
+            Task: fill each <|AUDIO|> placeholder with the correct number of audio embeddings.
+            Example:
+                X (5 tokens), Y (3 tokens), Z (8 tokens)
+                X, Y are in the same sequence (in-context learning)
+            if right padding
+                input_ids: [
+                    a b c d e f X g h i j k Y l m
+                    o p q r Z s t u v _ _ _ _ _ _
+                ]
+                input_ids should be: [
+                    a b c d e f X X X X X g h i j k Y Y Y l m
+                    o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
+                ]
+                labels should be: [
+                    a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
+                    o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
+                ]
+            elif left padding
+                input_ids: [
+                    a b c d e f X g h i j k Y l m
+                    _ _ _ _ _ _ o p q r Z s t u v
+                ]
+                input_ids should be: [
+                    a b c d e f X X X X X g h i j k Y Y Y l m
+                    _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
+                ]
+                labels should be: [
+                    a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
+                    _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
+                ]
+        Edge cases:
+            * If the tokens are the same but the audio token counts differ, left vs. right padding cannot be inferred:
+            ```python
+            url1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+            audio1, _ = librosa.load(BytesIO(urlopen(url1).read()), sr=processor.feature_extractor.sampling_rate)
+            url2 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
+            audio2, _ = librosa.load(BytesIO(urlopen(url2).read()), sr=processor.feature_extractor.sampling_rate)
+            prompts = [
+                "[INST] <|AUDIO|>\nWhat is that in this audio? [/INST]",
+                "[INST] <|AUDIO|>\nWhat is that in this audio?
[/INST]",
+            ]
+            inputs = processor(text=prompts, audios=[audio1, audio2], return_tensors='pt', padding=True).to("cuda")
+            # audio1 has 101 tokens, while audio2 has 72 tokens
+            ```
+
+                input_ids: [
+                    a b c d X g h
+                    i j Y k l m n
+                ]
+                where X is 3 tokens and Y is 5; this means that, after merging,
+                if left padding (batched generation)
+                    input_ids should be: [
+                        _ _ a b c d X X X g h
+                        i j Y Y Y Y Y k l m n
+                    ]
+                elif right padding (training)
+                    input_ids should be: [
+                        a b c d X X X g h _ _
+                        i j Y Y Y Y Y k l m n
+                    ]
+        """
+        num_audios, max_audio_tokens, embed_dim = audio_features.shape
+        audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
+            num_audio_tokens.device
+        ) < num_audio_tokens.unsqueeze(1)
+        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)
+        batch_size, sequence_length = input_ids.shape
+        _left_padding = torch.any(attention_mask[:, 0] == 0)
+        _right_padding = torch.any(attention_mask[:, -1] == 0)
+
+        left_padding = True
+        if batch_size > 1:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            elif not _left_padding and _right_padding:
+                left_padding = False
+            elif not _left_padding and not _right_padding:
+                # both sides are fully attended, so we cannot tell; fall back to the configured padding side
+                left_padding = self.padding_side == "left"
+            else:
+                # zeros on both sides: the attention_mask is invalid
+                raise ValueError(f"Both sides of attention_mask have zeros, which is invalid: {attention_mask}")
+
+        # 1. Create a mask to know where special audio tokens are
+        special_audio_token_mask = input_ids == self.config.audio_token_id
+        num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)
+
+        # In case the audio model or the language model has been offloaded to CPU, we need to manually
+        # set the corresponding tensors into their correct target device.
+        target_device = inputs_embeds.device
+        attention_mask = attention_mask.to(target_device)
+        input_ids = input_ids.to(target_device)
+        num_audio_tokens = num_audio_tokens.to(target_device)
+        batch_indices, non_audio_indices = torch.where(
+            (input_ids != self.config.audio_token_id) & (attention_mask == 1)
+        )
+
+        # 2. Compute the positions where text should be written
+        # Calculate new positions for text tokens in the merged audio-text sequence.
+        # `special_audio_token_mask` identifies audio tokens. Each audio token is expanded to `num_audio_tokens`
+        # positions, shifting subsequent text tokens by `num_audio_tokens - 1`.
+        # `torch.cumsum` computes how each audio token shifts subsequent text token positions.
+        token_placeholder_num = torch.zeros_like(input_ids)
+        token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1
+        token_placeholder_num = token_placeholder_num + 1
+        new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
+        max_token_num = token_placeholder_num.sum(-1).max()
+        nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
+        if left_padding:
+            new_token_positions += nb_audio_pad[:, None]  # offset for left padding
+        text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
+        batch_indices, non_audio_indices, text_to_overwrite = (
+            batch_indices.to(target_device),
+            non_audio_indices.to(target_device),
+            text_to_overwrite.to(target_device),
+        )
+
+        # 3.
Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_token_num, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_token_num, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_token_num), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device + ) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", + unk_token="", + additional_special_tokens=[], + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self) -> Dict[str, int]: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index < self.sp_model.get_piece_size(): + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["ReformerTokenizer"] diff --git 
a/docs/transformers/src/transformers/models/reformer/tokenization_reformer_fast.py b/docs/transformers/src/transformers/models/reformer/tokenization_reformer_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..a48441c55e5a787ece9cc7507f1c9746f72fded4 --- /dev/null +++ b/docs/transformers/src/transformers/models/reformer/tokenization_reformer_fast.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2020 The Trax Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model Reformer.""" + +import os +from shutil import copyfile +from typing import Optional, Tuple + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_reformer import ReformerTokenizer +else: + ReformerTokenizer = None + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +class ReformerTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Reformer tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. 
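+
+    A minimal usage sketch (the checkpoint id below is illustrative of any Reformer checkpoint that ships a
+    `spiece.model` / `tokenizer.json`):
+
+    ```python
+    >>> from transformers import ReformerTokenizerFast
+
+    >>> tokenizer = ReformerTokenizerFast.from_pretrained("google/reformer-crime-and-punishment")
+    >>> tokenizer("Hello world")["input_ids"]  # doctest: +SKIP
+    ```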
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = ReformerTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + eos_token="", + unk_token="", + additional_special_tokens=[], + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + eos_token=eos_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["ReformerTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/regnet/__init__.py b/docs/transformers/src/transformers/models/regnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cac770fdd0bcf669b4f723dbbd88afcbc098cd12 --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_regnet import * + from .modeling_flax_regnet import * + from .modeling_regnet import * + from .modeling_tf_regnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/regnet/configuration_regnet.py b/docs/transformers/src/transformers/models/regnet/configuration_regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c089f67297aa7764de4ccb54b2ba485d63f5a924 --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/configuration_regnet.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RegNet model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class RegNetConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`RegNetModel`]. It is used to instantiate a RegNet
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the RegNet
+    [facebook/regnet-y-040](https://huggingface.co/facebook/regnet-y-040) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        embedding_size (`int`, *optional*, defaults to 32):
+            Dimensionality (hidden size) for the embedding layer.
+        hidden_sizes (`List[int]`, *optional*, defaults to `[128, 192, 512, 1088]`):
+            Dimensionality (hidden size) at each stage.
+        depths (`List[int]`, *optional*, defaults to `[2, 6, 12, 2]`):
+            Depth (number of layers) for each stage.
+        groups_width (`int`, *optional*, defaults to 64):
+            The number of channels per group in the grouped `3x3` convolutions.
+        layer_type (`str`, *optional*, defaults to `"y"`):
+            The layer to use, it can be either `"x"` or `"y"`. An `x` layer is a ResNet's BottleNeck layer with
+            `reduction` fixed to `1`, while a `y` layer is an `x` layer with Squeeze-and-Excitation. Please refer to the
+            paper for a detailed explanation of how these layers were constructed.
+        hidden_act (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
+            are supported.
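+
+    Note that `downsample_in_first_stage` is not a constructor argument in this implementation: `__init__` below
+    always sets it to `True`, so the first stage downsamples its input with a stride of 2.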
+ + Example: + ```python + >>> from transformers import RegNetConfig, RegNetModel + + >>> # Initializing a RegNet regnet-y-40 style configuration + >>> configuration = RegNetConfig() + >>> # Initializing a model from the regnet-y-40 style configuration + >>> model = RegNetModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "regnet" + layer_types = ["x", "y"] + + def __init__( + self, + num_channels=3, + embedding_size=32, + hidden_sizes=[128, 192, 512, 1088], + depths=[2, 6, 12, 2], + groups_width=64, + layer_type="y", + hidden_act="relu", + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.groups_width = groups_width + self.layer_type = layer_type + self.hidden_act = hidden_act + # always downsample in the first stage + self.downsample_in_first_stage = True + + +__all__ = ["RegNetConfig"] diff --git a/docs/transformers/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py b/docs/transformers/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..95bd5b854282e37f36677fc94c2fc27f89efa7bd --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert RegNet 10B checkpoints vissl.""" +# You need to install a specific version of classy vision +# pip install git+https://github.com/FrancescoSaverioZuppichini/ClassyVision.git@convert_weights + +import argparse +import json +import os +import re +from collections import OrderedDict +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from pprint import pprint +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from classy_vision.models.regnet import RegNet, RegNetParams +from huggingface_hub import hf_hub_download +from torch import Tensor +from vissl.models.model_helpers import get_trunk_forward_outputs + +from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel +from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger() + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + name2module: Dict[str, nn.Module] = field(default_factory=OrderedDict) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor, name: str): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + self.traced.append(m) + self.name2module[name] = m + + def __call__(self, x: Tensor): + for name, m in self.module.named_modules(): + self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name))) + self.module(x) + [x.remove() for x in self.handles] + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return {k: v for k, v in self.name2module.items() if len(list(v.state_dict().keys())) > 0} + + +class FakeRegNetVisslWrapper(nn.Module): + """ + Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file. + """ + + def __init__(self, model: nn.Module): + super().__init__() + + feature_blocks: List[Tuple[str, nn.Module]] = [] + # - get the stem + feature_blocks.append(("conv1", model.stem)) + # - get all the feature blocks + for k, v in model.trunk_output.named_children(): + assert k.startswith("block"), f"Unexpected layer name {k}" + block_index = len(feature_blocks) + 1 + feature_blocks.append((f"res{block_index}", v)) + + self._feature_blocks = nn.ModuleDict(feature_blocks) + + def forward(self, x: Tensor): + return get_trunk_forward_outputs( + x, + out_feat_keys=None, + feature_blocks=self._feature_blocks, + ) + + +class FakeRegNetParams(RegNetParams): + """ + Used to instantiace a RegNet model from classy vision with the same depth as the 10B one but with super small + parameters, so we can trace it in memory. 
+ """ + + def get_expanded_params(self): + return [(8, 2, 2, 8, 1.0), (8, 2, 7, 8, 1.0), (8, 2, 17, 8, 1.0), (8, 2, 1, 8, 1.0)] + + +def get_from_to_our_keys(model_name: str) -> Dict[str, str]: + """ + Returns a dictionary that maps from original model's key -> our implementation's keys + """ + + # create our model (with small weights) + our_config = RegNetConfig(depths=[2, 7, 17, 1], hidden_sizes=[8, 8, 8, 8], groups_width=8) + if "in1k" in model_name: + our_model = RegNetForImageClassification(our_config) + else: + our_model = RegNetModel(our_config) + # create from model (with small weights) + from_model = FakeRegNetVisslWrapper( + RegNet(FakeRegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52)) + ) + + with torch.no_grad(): + from_model = from_model.eval() + our_model = our_model.eval() + + x = torch.randn((1, 3, 32, 32)) + # trace both + dest_tracker = Tracker(our_model) + dest_traced = dest_tracker(x).parametrized + + pprint(dest_tracker.name2module) + src_tracker = Tracker(from_model) + src_traced = src_tracker(x).parametrized + + # convert the keys -> module dict to keys -> params + def to_params_dict(dict_with_modules): + params_dict = OrderedDict() + for name, module in dict_with_modules.items(): + for param_name, param in module.state_dict().items(): + params_dict[f"{name}.{param_name}"] = param + return params_dict + + from_to_ours_keys = {} + + src_state_dict = to_params_dict(src_traced) + dst_state_dict = to_params_dict(dest_traced) + + for (src_key, src_param), (dest_key, dest_param) in zip(src_state_dict.items(), dst_state_dict.items()): + from_to_ours_keys[src_key] = dest_key + logger.info(f"{src_key} -> {dest_key}") + # if "in1k" was in the model_name it means it must have a classification head (was finetuned) + if "in1k" in model_name: + from_to_ours_keys["0.clf.0.weight"] = "classifier.1.weight" + from_to_ours_keys["0.clf.0.bias"] = "classifier.1.bias" + + return from_to_ours_keys + + +def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True): + filename = "imagenet-1k-id2label.json" + num_labels = 1000 + + repo_id = "huggingface/label-files" + num_labels = num_labels + id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text()) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id) + + names_to_config = { + "regnet-y-10b-seer": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + # finetuned on imagenet + "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + } + + # add seer weights logic + def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]: + files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu") + # check if we have a head, if yes add it + model_state_dict = files["classy_state_dict"]["base_model"]["model"] + return model_state_dict["trunk"], model_state_dict["heads"] + + names_to_from_model = { + "regnet-y-10b-seer": partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch", + ), + "regnet-y-10b-seer-in1k": partial( + load_using_classy_vision, + 
"https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch", + ), + } + + from_to_ours_keys = get_from_to_our_keys(model_name) + + if not (save_directory / f"{model_name}.pth").exists(): + logger.info("Loading original state_dict.") + from_state_dict_trunk, from_state_dict_head = names_to_from_model[model_name]() + from_state_dict = from_state_dict_trunk + if "in1k" in model_name: + # add the head + from_state_dict = {**from_state_dict_trunk, **from_state_dict_head} + logger.info("Done!") + + converted_state_dict = {} + + not_used_keys = list(from_state_dict.keys()) + regex = r"\.block.-part." + # this is "interesting", so the original checkpoints have `block[0,1]-part` in each key name, we remove it + for key in from_state_dict.keys(): + # remove the weird "block[0,1]-part" from the key + src_key = re.sub(regex, "", key) + # now src_key from the model checkpoints is the one we got from the original model after tracing, so use it to get the correct destination key + dest_key = from_to_ours_keys[src_key] + # store the parameter with our key + converted_state_dict[dest_key] = from_state_dict[key] + not_used_keys.remove(key) + # check that all keys have been updated + assert len(not_used_keys) == 0, f"Some keys where not used {','.join(not_used_keys)}" + + logger.info(f"The following keys were not used: {','.join(not_used_keys)}") + + # save our state dict to disk + torch.save(converted_state_dict, save_directory / f"{model_name}.pth") + + del converted_state_dict + else: + logger.info("The state_dict was already stored on disk.") + if push_to_hub: + logger.info(f"Token is {os.environ['HF_TOKEN']}") + logger.info("Loading our model.") + # create our model + our_config = names_to_config[model_name] + our_model_func = RegNetModel + if "in1k" in model_name: + our_model_func = RegNetForImageClassification + with torch.device("meta"): + our_model = our_model_func(our_config) + logger.info("Loading state_dict in our model.") + # load state dict + state_dict_keys = our_model.state_dict().keys() + state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True) + fixed_state_dict = state_dict = {our_model._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()} + _load_state_dict_into_meta_model( + our_model, + fixed_state_dict, + start_prefix="", + expected_keys=state_dict_keys, + ) + logger.info("Finally, pushing!") + # push it to hub + our_model.push_to_hub( + repo_path_or_name=save_directory / model_name, + commit_message="Add model", + output_dir=save_directory / model_name, + ) + size = 384 + # we can use the convnext one + image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size) + image_processor.push_to_hub( + repo_path_or_name=save_directory / model_name, + commit_message="Add image processor", + output_dir=save_directory / model_name, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default=None, + type=str, + help=( + "The name of the model you wish to convert, it must be one of the supported regnet* architecture," + " currently: regnetx-*, regnety-*. If `None`, all of them will the converted." 
+ ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and image processor to the hub.", + ) + + args = parser.parse_args() + + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/regnet/convert_regnet_to_pytorch.py b/docs/transformers/src/transformers/models/regnet/convert_regnet_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..9544400416bd97e5011cc6c7640069ce10b7e27f --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/convert_regnet_to_pytorch.py @@ -0,0 +1,458 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RegNet checkpoints from timm and vissl.""" + +import argparse +import json +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import timm +import torch +import torch.nn as nn +from classy_vision.models.regnet import RegNet, RegNetParams, RegNetY32gf, RegNetY64gf, RegNetY128gf +from huggingface_hub import hf_hub_download +from torch import Tensor +from vissl.models.model_helpers import get_trunk_forward_outputs + +from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger() + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + self.traced.append(m) + + def __call__(self, x: Tensor): + for m in self.module.modules(): + self.handles.append(m.register_forward_hook(self._forward_hook)) + self.module(x) + [x.remove() for x in self.handles] + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced)) + + +@dataclass +class ModuleTransfer: + src: nn.Module + dest: nn.Module + verbose: int = 1 + src_skip: List = field(default_factory=list) + dest_skip: List = field(default_factory=list) + raise_if_mismatch: bool = True + + def __call__(self, x: Tensor): + """ + Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the + hood we tracked all the operations in both modules. 
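+
+        A minimal sketch of the intended use (the two small modules here are illustrative stand-ins):
+
+        ```python
+        src = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
+        dest = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
+        ModuleTransfer(src=src, dest=dest)(torch.randn((1, 3, 224, 224)))
+        # `dest` now holds `src`'s conv and batch-norm parameters
+        ```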
+ """ + dest_traced = Tracker(self.dest)(x).parametrized + src_traced = Tracker(self.src)(x).parametrized + + src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced)) + dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced)) + + if len(dest_traced) != len(src_traced) and self.raise_if_mismatch: + raise Exception( + f"Numbers of operations are different. Source module has {len(src_traced)} operations while" + f" destination module has {len(dest_traced)}." + ) + + for dest_m, src_m in zip(dest_traced, src_traced): + dest_m.load_state_dict(src_m.state_dict()) + if self.verbose == 1: + print(f"Transfered from={src_m} to={dest_m}") + + +class FakeRegNetVisslWrapper(nn.Module): + """ + Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file. + """ + + def __init__(self, model: nn.Module): + super().__init__() + + feature_blocks: List[Tuple[str, nn.Module]] = [] + # - get the stem + feature_blocks.append(("conv1", model.stem)) + # - get all the feature blocks + for k, v in model.trunk_output.named_children(): + assert k.startswith("block"), f"Unexpected layer name {k}" + block_index = len(feature_blocks) + 1 + feature_blocks.append((f"res{block_index}", v)) + + self._feature_blocks = nn.ModuleDict(feature_blocks) + + def forward(self, x: Tensor): + return get_trunk_forward_outputs( + x, + out_feat_keys=None, + feature_blocks=self._feature_blocks, + ) + + +class NameToFromModelFuncMap(dict): + """ + A Dictionary with some additional logic to return a function that creates the correct original model. + """ + + def convert_name_to_timm(self, x: str) -> str: + x_split = x.split("-") + return x_split[0] + x_split[1] + "_" + "".join(x_split[2:]) + + def __getitem__(self, x: str) -> Callable[[], Tuple[nn.Module, Dict]]: + # default to timm! + if x not in self: + x = self.convert_name_to_timm(x) + val = partial(lambda: (timm.create_model(x, pretrained=True).eval(), None)) + + else: + val = super().__getitem__(x) + + return val + + +class NameToOurModelFuncMap(dict): + """ + A Dictionary with some additional logic to return the correct hugging face RegNet class reference. 
+ """ + + def __getitem__(self, x: str) -> Callable[[], nn.Module]: + if "seer" in x and "in1k" not in x: + val = RegNetModel + else: + val = RegNetForImageClassification + return val + + +def manually_copy_vissl_head(from_state_dict, to_state_dict, keys: List[Tuple[str, str]]): + for from_key, to_key in keys: + to_state_dict[to_key] = from_state_dict[from_key].clone() + print(f"Copied key={from_key} to={to_key}") + return to_state_dict + + +def convert_weight_and_push( + name: str, + from_model_func: Callable[[], nn.Module], + our_model_func: Callable[[], nn.Module], + config: RegNetConfig, + save_directory: Path, + push_to_hub: bool = True, +): + print(f"Converting {name}...") + with torch.no_grad(): + from_model, from_state_dict = from_model_func() + our_model = our_model_func(config).eval() + module_transfer = ModuleTransfer(src=from_model, dest=our_model, raise_if_mismatch=False) + x = torch.randn((1, 3, 224, 224)) + module_transfer(x) + + if from_state_dict is not None: + keys = [] + # for seer - in1k finetuned we have to manually copy the head + if "seer" in name and "in1k" in name: + keys = [("0.clf.0.weight", "classifier.1.weight"), ("0.clf.0.bias", "classifier.1.bias")] + to_state_dict = manually_copy_vissl_head(from_state_dict, our_model.state_dict(), keys) + our_model.load_state_dict(to_state_dict) + + our_outputs = our_model(x, output_hidden_states=True) + our_output = ( + our_outputs.logits if isinstance(our_model, RegNetForImageClassification) else our_outputs.last_hidden_state + ) + + from_output = from_model(x) + from_output = from_output[-1] if isinstance(from_output, list) else from_output + + # now since I don't want to use any config files, vissl seer model doesn't actually have an head, so let's just check the last hidden state + if "seer" in name and "in1k" in name: + our_output = our_outputs.hidden_states[-1] + + assert torch.allclose(from_output, our_output), "The model logits don't match the original one." 
+ + if push_to_hub: + our_model.push_to_hub( + repo_path_or_name=save_directory / name, + commit_message="Add model", + use_temp_dir=True, + ) + + size = 224 if "seer" not in name else 384 + # we can use the convnext one + image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size) + image_processor.push_to_hub( + repo_path_or_name=save_directory / name, + commit_message="Add image processor", + use_temp_dir=True, + ) + + print(f"Pushed {name}") + + +def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True): + filename = "imagenet-1k-id2label.json" + num_labels = 1000 + expected_shape = (1, num_labels) + + repo_id = "huggingface/label-files" + num_labels = num_labels + id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text()) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id) + + names_to_config = { + "regnet-x-002": ImageNetPreTrainedConfig( + depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8, layer_type="x" + ), + "regnet-x-004": ImageNetPreTrainedConfig( + depths=[1, 2, 7, 12], hidden_sizes=[32, 64, 160, 384], groups_width=16, layer_type="x" + ), + "regnet-x-006": ImageNetPreTrainedConfig( + depths=[1, 3, 5, 7], hidden_sizes=[48, 96, 240, 528], groups_width=24, layer_type="x" + ), + "regnet-x-008": ImageNetPreTrainedConfig( + depths=[1, 3, 7, 5], hidden_sizes=[64, 128, 288, 672], groups_width=16, layer_type="x" + ), + "regnet-x-016": ImageNetPreTrainedConfig( + depths=[2, 4, 10, 2], hidden_sizes=[72, 168, 408, 912], groups_width=24, layer_type="x" + ), + "regnet-x-032": ImageNetPreTrainedConfig( + depths=[2, 6, 15, 2], hidden_sizes=[96, 192, 432, 1008], groups_width=48, layer_type="x" + ), + "regnet-x-040": ImageNetPreTrainedConfig( + depths=[2, 5, 14, 2], hidden_sizes=[80, 240, 560, 1360], groups_width=40, layer_type="x" + ), + "regnet-x-064": ImageNetPreTrainedConfig( + depths=[2, 4, 10, 1], hidden_sizes=[168, 392, 784, 1624], groups_width=56, layer_type="x" + ), + "regnet-x-080": ImageNetPreTrainedConfig( + depths=[2, 5, 15, 1], hidden_sizes=[80, 240, 720, 1920], groups_width=120, layer_type="x" + ), + "regnet-x-120": ImageNetPreTrainedConfig( + depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112, layer_type="x" + ), + "regnet-x-160": ImageNetPreTrainedConfig( + depths=[2, 6, 13, 1], hidden_sizes=[256, 512, 896, 2048], groups_width=128, layer_type="x" + ), + "regnet-x-320": ImageNetPreTrainedConfig( + depths=[2, 7, 13, 1], hidden_sizes=[336, 672, 1344, 2520], groups_width=168, layer_type="x" + ), + # y variant + "regnet-y-002": ImageNetPreTrainedConfig(depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8), + "regnet-y-004": ImageNetPreTrainedConfig( + depths=[1, 3, 6, 6], hidden_sizes=[48, 104, 208, 440], groups_width=8 + ), + "regnet-y-006": ImageNetPreTrainedConfig( + depths=[1, 3, 7, 4], hidden_sizes=[48, 112, 256, 608], groups_width=16 + ), + "regnet-y-008": ImageNetPreTrainedConfig( + depths=[1, 3, 8, 2], hidden_sizes=[64, 128, 320, 768], groups_width=16 + ), + "regnet-y-016": ImageNetPreTrainedConfig( + depths=[2, 6, 17, 2], hidden_sizes=[48, 120, 336, 888], groups_width=24 + ), + "regnet-y-032": ImageNetPreTrainedConfig( + depths=[2, 5, 13, 1], hidden_sizes=[72, 216, 576, 1512], groups_width=24 + ), + 
"regnet-y-040": ImageNetPreTrainedConfig( + depths=[2, 6, 12, 2], hidden_sizes=[128, 192, 512, 1088], groups_width=64 + ), + "regnet-y-064": ImageNetPreTrainedConfig( + depths=[2, 7, 14, 2], hidden_sizes=[144, 288, 576, 1296], groups_width=72 + ), + "regnet-y-080": ImageNetPreTrainedConfig( + depths=[2, 4, 10, 1], hidden_sizes=[168, 448, 896, 2016], groups_width=56 + ), + "regnet-y-120": ImageNetPreTrainedConfig( + depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112 + ), + "regnet-y-160": ImageNetPreTrainedConfig( + depths=[2, 4, 11, 1], hidden_sizes=[224, 448, 1232, 3024], groups_width=112 + ), + "regnet-y-320": ImageNetPreTrainedConfig( + depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232 + ), + # models created by SEER -> https://arxiv.org/abs/2202.08360 + "regnet-y-320-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232), + "regnet-y-640-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328), + "regnet-y-1280-seer": RegNetConfig( + depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264 + ), + "regnet-y-2560-seer": RegNetConfig( + depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640 + ), + "regnet-y-10b-seer": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + # finetuned on imagenet + "regnet-y-320-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232 + ), + "regnet-y-640-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328 + ), + "regnet-y-1280-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264 + ), + "regnet-y-2560-seer-in1k": ImageNetPreTrainedConfig( + depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640 + ), + "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + } + + names_to_ours_model_map = NameToOurModelFuncMap() + names_to_from_model_map = NameToFromModelFuncMap() + # add seer weights logic + + def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> Tuple[nn.Module, Dict]: + files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu") + model = model_func() + # check if we have a head, if yes add it + model_state_dict = files["classy_state_dict"]["base_model"]["model"] + state_dict = model_state_dict["trunk"] + model.load_state_dict(state_dict) + return model.eval(), model_state_dict["heads"] + + # pretrained + names_to_from_model_map["regnet-y-320-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch", + lambda: FakeRegNetVisslWrapper(RegNetY32gf()), + ) + + names_to_from_model_map["regnet-y-640-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch", + lambda: FakeRegNetVisslWrapper(RegNetY64gf()), + ) + + names_to_from_model_map["regnet-y-1280-seer"] = partial( + load_using_classy_vision, + 
"https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch", + lambda: FakeRegNetVisslWrapper(RegNetY128gf()), + ) + + names_to_from_model_map["regnet-y-10b-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch", + lambda: FakeRegNetVisslWrapper( + RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52)) + ), + ) + + # IN1K finetuned + names_to_from_model_map["regnet-y-320-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch", + lambda: FakeRegNetVisslWrapper(RegNetY32gf()), + ) + + names_to_from_model_map["regnet-y-640-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch", + lambda: FakeRegNetVisslWrapper(RegNetY64gf()), + ) + + names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch", + lambda: FakeRegNetVisslWrapper(RegNetY128gf()), + ) + + names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch", + lambda: FakeRegNetVisslWrapper( + RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52)) + ), + ) + + if model_name: + convert_weight_and_push( + model_name, + names_to_from_model_map[model_name], + names_to_ours_model_map[model_name], + names_to_config[model_name], + save_directory, + push_to_hub, + ) + else: + for model_name, config in names_to_config.items(): + convert_weight_and_push( + model_name, + names_to_from_model_map[model_name], + names_to_ours_model_map[model_name], + config, + save_directory, + push_to_hub, + ) + return config, expected_shape + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default=None, + type=str, + help=( + "The name of the model you wish to convert, it must be one of the supported regnet* architecture," + " currently: regnetx-*, regnety-*. If `None`, all of them will the converted." 
+ ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and image processor to the hub.", + ) + + args = parser.parse_args() + + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/regnet/modeling_flax_regnet.py b/docs/transformers/src/transformers/models/regnet/modeling_flax_regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f575e33db9d74dd937c851df77b589421f1c8082 --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/modeling_flax_regnet.py @@ -0,0 +1,822 @@ +# coding=utf-8 +# Copyright 2023 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from transformers import RegNetConfig +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutputWithNoAttention, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndNoAttention, + FlaxImageClassifierOutputWithNoAttention, +) +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, +) + + +REGNET_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +REGNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`RegNetImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.resnet.modeling_flax_resnet.Identity +class Identity(nn.Module): + """Identity function.""" + + @nn.compact + def __call__(self, x, **kwargs): + return x + + +class FlaxRegNetConvLayer(nn.Module): + out_channels: int + kernel_size: int = 3 + stride: int = 1 + groups: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(self.kernel_size, self.kernel_size), + strides=self.stride, + padding=self.kernel_size // 2, + feature_group_count=self.groups, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetEmbeddings(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embedder = FlaxRegNetConvLayer( + self.config.embedding_size, + kernel_size=3, + stride=2, + activation=self.config.hidden_act, + dtype=self.dtype, + ) + + def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + num_channels = pixel_values.shape[-1] + if num_channels != self.config.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = self.embedder(pixel_values, deterministic=deterministic) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetShortCut with ResNet->RegNet +class FlaxRegNetShortCut(nn.Module): + """ + RegNet shortcut, used to project the residual features to the correct size. 
If needed, it is also used to + downsample the input using `stride=2`. + """ + + out_channels: int + stride: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(1, 1), + strides=self.stride, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + return hidden_state + + +class FlaxRegNetSELayerCollection(nn.Module): + in_channels: int + reduced_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_1 = nn.Conv( + self.reduced_channels, + kernel_size=(1, 1), + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + name="0", + ) # 0 is the name used in corresponding pytorch implementation + self.conv_2 = nn.Conv( + self.in_channels, + kernel_size=(1, 1), + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + name="2", + ) # 2 is the name used in corresponding pytorch implementation + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + hidden_state = self.conv_1(hidden_state) + hidden_state = nn.relu(hidden_state) + hidden_state = self.conv_2(hidden_state) + attention = nn.sigmoid(hidden_state) + + return attention + + +class FlaxRegNetSELayer(nn.Module): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). + """ + + in_channels: int + reduced_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0))) + self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype) + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + pooled = self.pooler( + hidden_state, + window_shape=(hidden_state.shape[1], hidden_state.shape[2]), + strides=(hidden_state.shape[1], hidden_state.shape[2]), + ) + attention = self.attention(pooled) + hidden_state = hidden_state * attention + return hidden_state + + +class FlaxRegNetXLayerCollection(nn.Module): + config: RegNetConfig + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + groups = max(1, self.out_channels // self.config.groups_width) + + self.layer = [ + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + FlaxRegNetConvLayer( + self.out_channels, + stride=self.stride, + groups=groups, + activation=self.config.hidden_act, + dtype=self.dtype, + name="1", + ), + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=None, + dtype=self.dtype, + name="2", + ), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxRegNetXLayer(nn.Module): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. 
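+    (More precisely, as implemented in `FlaxRegNetXLayerCollection` above, the stack is a `1x1` convolution, a
+    grouped `3x3` convolution, and a final `1x1` convolution with no activation, plus an optional shortcut
+    projection.)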
+ """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxRegNetShortCut( + self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + if should_apply_shortcut + else Identity() + ) + self.layer = FlaxRegNetXLayerCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetYLayerCollection(nn.Module): + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + groups = max(1, self.out_channels // self.config.groups_width) + + self.layer = [ + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + FlaxRegNetConvLayer( + self.out_channels, + stride=self.stride, + groups=groups, + activation=self.config.hidden_act, + dtype=self.dtype, + name="1", + ), + FlaxRegNetSELayer( + self.out_channels, + reduced_channels=int(round(self.in_channels / 4)), + dtype=self.dtype, + name="2", + ), + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=None, + dtype=self.dtype, + name="3", + ), + ] + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state) + return hidden_state + + +class FlaxRegNetYLayer(nn.Module): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + + self.shortcut = ( + FlaxRegNetShortCut( + self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + if should_apply_shortcut + else Identity() + ) + self.layer = FlaxRegNetYLayerCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetStageLayersCollection(nn.Module): + """ + A RegNet stage composed by stacked layers. 
+ """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer + + layers = [ + # downsampling is done in the first layer with stride of 2 + layer( + self.config, + self.in_channels, + self.out_channels, + stride=self.stride, + dtype=self.dtype, + name="0", + ) + ] + + for i in range(self.depth - 1): + layers.append( + layer( + self.config, + self.out_channels, + self.out_channels, + dtype=self.dtype, + name=str(i + 1), + ) + ) + + self.layers = layers + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = x + for layer in self.layers: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet +class FlaxRegNetStage(nn.Module): + """ + A RegNet stage composed by stacked layers. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = FlaxRegNetStageLayersCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + depth=self.depth, + dtype=self.dtype, + ) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + return self.layers(x, deterministic=deterministic) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet +class FlaxRegNetStageCollection(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) + stages = [ + FlaxRegNetStage( + self.config, + self.config.embedding_size, + self.config.hidden_sizes[0], + stride=2 if self.config.downsample_in_first_stage else 1, + depth=self.config.depths[0], + dtype=self.dtype, + name="0", + ) + ] + + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): + stages.append( + FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) + ) + + self.stages = stages + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + hidden_state = stage_module(hidden_state, deterministic=deterministic) + + return hidden_state, hidden_states + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder with ResNet->RegNet +class FlaxRegNetEncoder(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_state, hidden_states = self.stages( + hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic + ) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + if not 
return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return FlaxBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel with ResNet->RegNet,resnet->regnet,RESNET->REGNET +class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: RegNetConfig, + input_shape=(1, 224, 224, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + rngs = {"params": rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False) + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + params: dict = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + # Handle any PRNG if needed + rngs = {} + + return self.module.apply( + { + "params": params["params"] if params is not None else self.params["params"], + "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], + }, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True + ) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule with ResNet->RegNet +class FlaxRegNetModule(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype) + + # Adaptive average pooling used in resnet + self.pooler = partial( + nn.avg_pool, + padding=((0, 0), (0, 0)), + ) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + 
return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler( + last_hidden_state, + window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + ).transpose(0, 3, 1, 2) + + last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + "The bare RegNet model outputting raw features without any specific head on top.", + REGNET_START_DOCSTRING, +) +class FlaxRegNetModel(FlaxRegNetPreTrainedModel): + module_class = FlaxRegNetModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxRegNetModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") + >>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxRegNetModel, + output_type=FlaxBaseModelOutputWithPooling, + config_class=RegNetConfig, +) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection with ResNet->RegNet +class FlaxRegNetClassifierCollection(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.classifier(x) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule with ResNet->RegNet,resnet->regnet,RESNET->REGNET +class FlaxRegNetForImageClassificationModule(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype) + + if self.config.num_labels > 0: + self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype) + else: + self.classifier = Identity() + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet( + pixel_values, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output[:, :, 0, 0]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return 
FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + REGNET_START_DOCSTRING, +) +class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel): + module_class = FlaxRegNetForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") + >>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxRegNetForImageClassification, + output_type=FlaxImageClassifierOutputWithNoAttention, + config_class=RegNetConfig, +) + + +__all__ = ["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/regnet/modeling_regnet.py b/docs/transformers/src/transformers/models/regnet/modeling_regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd0a4c634d5921417fda15c3fd02e1c01d5e000 --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/modeling_regnet.py @@ -0,0 +1,454 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch RegNet model.""" + +import math +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_regnet import RegNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "RegNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040" +_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class RegNetConvLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + activation: Optional[str] = "relu", + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=groups, + bias=False, + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else nn.Identity() + + def forward(self, hidden_state): + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RegNetEmbeddings(nn.Module): + """ + RegNet Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: RegNetConfig): + super().__init__() + self.embedder = RegNetConvLayer( + config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act + ) + self.num_channels = config.num_channels + + def forward(self, pixel_values): + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = self.embedder(pixel_values) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet +class RegNetShortCut(nn.Module): + """ + RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.normalization = nn.BatchNorm2d(out_channels) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class RegNetSELayer(nn.Module): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). 
+ """ + + def __init__(self, in_channels: int, reduced_channels: int): + super().__init__() + + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + self.attention = nn.Sequential( + nn.Conv2d(in_channels, reduced_channels, kernel_size=1), + nn.ReLU(), + nn.Conv2d(reduced_channels, in_channels, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, hidden_state): + # b c h w -> b c 1 1 + pooled = self.pooler(hidden_state) + attention = self.attention(pooled) + hidden_state = hidden_state * attention + return hidden_state + + +class RegNetXLayer(nn.Module): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act), + RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act), + RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RegNetYLayer(nn.Module): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act), + RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act), + RegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4))), + RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RegNetStage(nn.Module): + """ + A RegNet stage composed by stacked layers. 
+ """ + + def __init__( + self, + config: RegNetConfig, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + ): + super().__init__() + + layer = RegNetXLayer if config.layer_type == "x" else RegNetYLayer + + self.layers = nn.Sequential( + # downsampling is done in the first layer with stride of 2 + layer( + config, + in_channels, + out_channels, + stride=stride, + ), + *[layer(config, out_channels, out_channels) for _ in range(depth - 1)], + ) + + def forward(self, hidden_state): + hidden_state = self.layers(hidden_state) + return hidden_state + + +class RegNetEncoder(nn.Module): + def __init__(self, config: RegNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input + self.stages.append( + RegNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append(RegNetStage(config, in_channels, out_channels, depth=depth)) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + +class RegNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + _no_split_modules = ["RegNetYLayer"] + + # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +REGNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and + behavior. + + Parameters: + config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +REGNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ConvNextImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare RegNet model outputting raw features without any specific head on top.", + REGNET_START_DOCSTRING, +) +# Copied from transformers.models.resnet.modeling_resnet.ResNetModel with RESNET->REGNET,ResNet->RegNet +class RegNetModel(RegNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embedder = RegNetEmbeddings(config) + self.encoder = RegNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
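+
+    When `labels` are provided, the forward pass computes a regression (MSE), single-label (cross-entropy), or
+    multi-label (BCE-with-logits) loss, depending on `config.problem_type`.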
+ """, + REGNET_START_DOCSTRING, +) +# Copied from transformers.models.resnet.modeling_resnet.ResNetForImageClassification with RESNET->REGNET,ResNet->RegNet,resnet->regnet +class RegNetForImageClassification(RegNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.regnet = RegNetModel(config) + # classification head + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(), + ) + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +__all__ = ["RegNetForImageClassification", "RegNetModel", "RegNetPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/regnet/modeling_tf_regnet.py b/docs/transformers/src/transformers/models/regnet/modeling_tf_regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0cc5d562b97559878f86b7b6316fa2c403b675 --- /dev/null +++ b/docs/transformers/src/transformers/models/regnet/modeling_tf_regnet.py @@ -0,0 +1,611 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorFlow RegNet model.""" + +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFBaseModelOutputWithPoolingAndNoAttention, + TFSequenceClassifierOutput, +) +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import logging +from .configuration_regnet import RegNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "RegNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040" +_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class TFRegNetConvLayer(keras.layers.Layer): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + activation: Optional[str] = "relu", + **kwargs, + ): + super().__init__(**kwargs) + # The padding and conv has been verified in + # https://colab.research.google.com/gist/sayakpaul/854bc10eeaf21c9ee2119e0b9f3841a7/scratchpad.ipynb + self.padding = keras.layers.ZeroPadding2D(padding=kernel_size // 2) + self.convolution = keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + padding="VALID", + groups=groups, + use_bias=False, + name="convolution", + ) + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.activation = ACT2FN[activation] if activation is not None else tf.identity + self.in_channels = in_channels + self.out_channels = out_channels + + def call(self, hidden_state): + hidden_state = self.convolution(self.padding(hidden_state)) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution", None) is not None: + with tf.name_scope(self.convolution.name): + self.convolution.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFRegNetEmbeddings(keras.layers.Layer): + """ + RegNet Embeddings (stem) composed of a single aggressive convolution. 
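+    The stem is a single `3x3`, stride-2 convolution, so it downsamples the spatial resolution of the input by a
+    factor of 2.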
+ """ + + def __init__(self, config: RegNetConfig, **kwargs): + super().__init__(**kwargs) + self.num_channels = config.num_channels + self.embedder = TFRegNetConvLayer( + in_channels=config.num_channels, + out_channels=config.embedding_size, + kernel_size=3, + stride=2, + activation=config.hidden_act, + name="embedder", + ) + + def call(self, pixel_values): + num_channels = shape_list(pixel_values)[1] + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + hidden_state = self.embedder(pixel_values) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedder", None) is not None: + with tf.name_scope(self.embedder.name): + self.embedder.build(None) + + +class TFRegNetShortCut(keras.layers.Layer): + """ + RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs): + super().__init__(**kwargs) + self.convolution = keras.layers.Conv2D( + filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" + ) + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.in_channels = in_channels + self.out_channels = out_channels + + def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: + return self.normalization(self.convolution(inputs), training=training) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution", None) is not None: + with tf.name_scope(self.convolution.name): + self.convolution.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFRegNetSELayer(keras.layers.Layer): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). 
+ """ + + def __init__(self, in_channels: int, reduced_channels: int, **kwargs): + super().__init__(**kwargs) + self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler") + self.attention = [ + keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"), + keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"), + ] + self.in_channels = in_channels + self.reduced_channels = reduced_channels + + def call(self, hidden_state): + # [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels] + pooled = self.pooler(hidden_state) + for layer_module in self.attention: + pooled = layer_module(pooled) + hidden_state = hidden_state * pooled + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build((None, None, None, None)) + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention[0].name): + self.attention[0].build([None, None, None, self.in_channels]) + with tf.name_scope(self.attention[1].name): + self.attention[1].build([None, None, None, self.reduced_channels]) + + +class TFRegNetXLayer(keras.layers.Layer): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs): + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else keras.layers.Activation("linear", name="shortcut") + ) + # `self.layers` instead of `self.layer` because that is a reserved argument. + self.layers = [ + TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"), + TFRegNetConvLayer( + out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1" + ), + TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.2"), + ] + self.activation = ACT2FN[config.hidden_act] + + def call(self, hidden_state): + residual = hidden_state + for layer_module in self.layers: + hidden_state = layer_module(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "shortcut", None) is not None: + with tf.name_scope(self.shortcut.name): + self.shortcut.build(None) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFRegNetYLayer(keras.layers.Layer): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. 
+ """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs): + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else keras.layers.Activation("linear", name="shortcut") + ) + self.layers = [ + TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"), + TFRegNetConvLayer( + out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1" + ), + TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"), + TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.3"), + ] + self.activation = ACT2FN[config.hidden_act] + + def call(self, hidden_state): + residual = hidden_state + for layer_module in self.layers: + hidden_state = layer_module(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "shortcut", None) is not None: + with tf.name_scope(self.shortcut.name): + self.shortcut.build(None) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFRegNetStage(keras.layers.Layer): + """ + A RegNet stage composed by stacked layers. + """ + + def __init__( + self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs + ): + super().__init__(**kwargs) + + layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer + self.layers = [ + # downsampling is done in the first layer with stride of 2 + layer(config, in_channels, out_channels, stride=stride, name="layers.0"), + *[layer(config, out_channels, out_channels, name=f"layers.{i + 1}") for i in range(depth - 1)], + ] + + def call(self, hidden_state): + for layer_module in self.layers: + hidden_state = layer_module(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFRegNetEncoder(keras.layers.Layer): + def __init__(self, config: RegNetConfig, **kwargs): + super().__init__(**kwargs) + self.stages = [] + # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input + self.stages.append( + TFRegNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + name="stages.0", + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])): + self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}")) + + def call( + self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> TFBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = 
hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + for stage in self.stages: + with tf.name_scope(stage.name): + stage.build(None) + + +@keras_serializable +class TFRegNetMainLayer(keras.layers.Layer): + config_class = RegNetConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.embedder = TFRegNetEmbeddings(config, name="embedder") + self.encoder = TFRegNetEncoder(config, name="encoder") + self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler") + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> TFBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = self.pooler(last_hidden_state) + + # Change to NCHW output format have uniformity in the modules + pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2)) + last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) + + # Change the other hidden state outputs to NCHW as well + if output_hidden_states: + hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]]) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedder", None) is not None: + with tf.name_scope(self.embedder.name): + self.embedder.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build((None, None, None, None)) + + +class TFRegNetPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + + @property + def input_signature(self): + return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)} + + +REGNET_START_DOCSTRING = r""" + This model is a Tensorflow + [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. 
Use it as a
+    regular Tensorflow Module and refer to the Tensorflow documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+REGNET_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`ConvNextImageProcessor.__call__`] for details.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare RegNet model outputting raw features without any specific head on top.",
+    REGNET_START_DOCSTRING,
+)
+class TFRegNetModel(TFRegNetPreTrainedModel):
+    def __init__(self, config: RegNetConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.regnet = TFRegNetMainLayer(config, name="regnet")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutputWithPoolingAndNoAttention,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def call(
+        self,
+        pixel_values: tf.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.regnet(
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        if not return_dict:
+            return (outputs[0],) + outputs[1:]
+
+        return TFBaseModelOutputWithPoolingAndNoAttention(
+            last_hidden_state=outputs.last_hidden_state,
+            pooler_output=outputs.pooler_output,
+            hidden_states=outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "regnet", None) is not None:
+            with tf.name_scope(self.regnet.name):
+                self.regnet.build(None)
+
+
+@add_start_docstrings(
+    """
+    RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+    ImageNet.
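+
+    When `labels` are provided, the loss is computed with the shared `TFSequenceClassificationLoss` helper
+    (`hf_compute_loss`).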
+ """, + REGNET_START_DOCSTRING, +) +class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: RegNetConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.regnet = TFRegNetMainLayer(config, name="regnet") + # classification head + self.classifier = [ + keras.layers.Flatten(), + keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity, + ] + + @unpack_inputs + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + flattened_output = self.classifier[0](pooled_output) + logits = self.classifier[1](flattened_output) + + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "regnet", None) is not None: + with tf.name_scope(self.regnet.name): + self.regnet.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier[1].name): + self.classifier[1].build([None, None, None, self.config.hidden_sizes[-1]]) + + +__all__ = ["TFRegNetForImageClassification", "TFRegNetModel", "TFRegNetPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/rembert/__init__.py b/docs/transformers/src/transformers/models/rembert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38566f502ad0ccbe3ecf04c003b16faf1c3af112 --- /dev/null +++ b/docs/transformers/src/transformers/models/rembert/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_rembert import *
+    from .modeling_rembert import *
+    from .modeling_tf_rembert import *
+    from .tokenization_rembert import *
+    from .tokenization_rembert_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/rembert/configuration_rembert.py b/docs/transformers/src/transformers/models/rembert/configuration_rembert.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b1fbc325c244ddbb7c5f2c76592702d6daabee
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rembert/configuration_rembert.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RemBERT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class RemBertConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`RemBertModel`]. It is used to instantiate a
+    RemBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the RemBERT
+    [google/rembert](https://huggingface.co/google/rembert) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 250300):
+            Vocabulary size of the RemBERT model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`RemBertModel`] or [`TFRemBertModel`].
+        hidden_size (`int`, *optional*, defaults to 1152):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 18):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        input_embedding_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the input embeddings.
+        output_embedding_size (`int`, *optional*, defaults to 1664):
+            Dimensionality of the output embeddings.
+ intermediate_size (`int`, *optional*, defaults to 4608): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0): + The dropout ratio for the attention probabilities. + classifier_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the classifier layer when fine-tuning. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RemBertModel`] or [`TFRemBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
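+
+    Note: RemBERT decouples the embedding widths from the encoder width, so the large vocabulary only pays for a
+    small `input_embedding_size` matrix on the input side while the encoder itself runs at `hidden_size`. A rough
+    comparison of the input-embedding parameter counts under the defaults above (illustrative arithmetic only, not
+    part of the library API):
+
+    ```python
+    >>> from transformers import RemBertConfig
+
+    >>> config = RemBertConfig()
+    >>> factored = config.vocab_size * config.input_embedding_size  # 250300 * 256, ~64M parameters
+    >>> bert_style = config.vocab_size * config.hidden_size  # 250300 * 1152, ~288M if tied to hidden_size
+    >>> factored < bert_style
+    True
+    ```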
+
+    Example:
+
+    ```python
+    >>> from transformers import RemBertModel, RemBertConfig
+
+    >>> # Initializing a RemBERT style configuration
+    >>> configuration = RemBertConfig()
+
+    >>> # Initializing a model from the RemBERT style configuration
+    >>> model = RemBertModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "rembert"
+
+    def __init__(
+        self,
+        vocab_size=250300,
+        hidden_size=1152,
+        num_hidden_layers=32,
+        num_attention_heads=18,
+        input_embedding_size=256,
+        output_embedding_size=1664,
+        intermediate_size=4608,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        classifier_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=312,
+        eos_token_id=313,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.input_embedding_size = input_embedding_size
+        self.output_embedding_size = output_embedding_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.classifier_dropout_prob = classifier_dropout_prob
+        self.initializer_range = initializer_range
+        self.type_vocab_size = type_vocab_size
+        self.layer_norm_eps = layer_norm_eps
+        self.use_cache = use_cache
+        self.tie_word_embeddings = False
+
+
+class RemBertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+
+__all__ = ["RemBertConfig", "RemBertOnnxConfig"]
diff --git a/docs/transformers/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..622d507080e4460a068493dba8d26d3b747657a0
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
@@ -0,0 +1,62 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
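+#
+# Example invocation (paths below are hypothetical; the flags match the argparse
+# definitions at the bottom of this script):
+#
+#   python convert_rembert_tf_checkpoint_to_pytorch.py \
+#       --tf_checkpoint_path /path/to/tf_checkpoint \
+#       --rembert_config_file /path/to/rembert_config.json \
+#       --pytorch_dump_path /path/to/pytorch_model.bin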
+"""Convert RemBERT checkpoint.""" + +import argparse + +import torch + +from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = RemBertConfig.from_json_file(bert_config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = RemBertModel(config) + + # Load weights from tf checkpoint + load_tf_weights_in_rembert(model, config, tf_checkpoint_path) + + # Save pytorch-model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--rembert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained RemBERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_rembert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.rembert_config_file, args.pytorch_dump_path) diff --git a/docs/transformers/src/transformers/models/rembert/modeling_rembert.py b/docs/transformers/src/transformers/models/rembert/modeling_rembert.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3250a71257d4e2e0ef8c26bc810b65ace46466 --- /dev/null +++ b/docs/transformers/src/transformers/models/rembert/modeling_rembert.py @@ -0,0 +1,1525 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch RemBERT model.""" + +import math +import os +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_rembert import RemBertConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RemBertConfig" +_CHECKPOINT_FOR_DOC = "google/rembert" + + +def load_tf_weights_in_rembert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + # Checkpoint is 12Gb, save memory by not loading useless variables + # Output embedding and cls are reset at classification time + if any(deny in name for deny in ("adam_v", "adam_m", "output_embedding", "cls")): + # logger.info("Skipping loading of %s", name) + continue + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + # Replace prefix with right one + name = name.replace("bert/", "rembert/") + # The pooler is a linear layer + # name = name.replace("pooler/dense", "pooler") + + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = 
getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class RemBertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.input_embedding_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.input_embedding_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.input_embedding_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.input_embedding_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->RemBert +class RemBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
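+        # Shape note: hidden_states is (batch_size, seq_len, hidden_size); taking
+        # index 0 along the sequence dimension keeps the first token's state, so
+        # pooled_output comes out as (batch_size, hidden_size).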
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class RemBertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Tuple[Tuple[torch.FloatTensor]] = None, + output_attentions: bool = False, + ) -> Tuple: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RemBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RemBert +class RemBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RemBertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = RemBertSelfAttention(config) + self.output = RemBertSelfOutput(config) + self.pruned_heads = set() + + # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + # Copied from transformers.models.bert.modeling_bert.BertAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: 
Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RemBert +class RemBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RemBert +class RemBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RemBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RemBertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RemBertAttention(config) + self.intermediate = RemBertIntermediate(config) + self.output = RemBertOutput(config) + + # Copied from transformers.models.bert.modeling_bert.BertLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self 
attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class RemBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size) + self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + hidden_states = self.embedding_hidden_mapping_in(hidden_states) + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RemBert +class RemBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class RemBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.output_embedding_size) + self.decoder = nn.Linear(config.output_embedding_size, config.vocab_size) + self.activation = ACT2FN[config.hidden_act] + self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RemBert +class RemBertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = RemBertLMPredictionHead(config) + + def 
forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class RemBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RemBertConfig + load_tf_weights = load_tf_weights_in_rembert + base_model_prefix = "rembert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +REMBERT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RemBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +REMBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than
+            the model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare RemBERT Model transformer outputting raw hidden-states without any specific head on top.",
+    REMBERT_START_DOCSTRING,
+)
+class RemBertModel(RemBertPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
+    and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward
+    pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = RemBertEmbeddings(config)
+        self.encoder = RemBertEncoder(config)
+
+        self.pooler = RemBertPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
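+
+        Example (illustrative usage sketch; outputs are not verified):
+
+        ```python
+        >>> from transformers import AutoTokenizer, RemBertModel
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/rembert")
+        >>> model = RemBertModel.from_pretrained("google/rembert")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> last_hidden_states = outputs.last_hidden_state  # (batch_size, sequence_length, config.hidden_size)
+        ```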
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return 
BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING) +class RemBertForMaskedLM(RemBertPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `RemBertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.rembert = RemBertModel(config, add_pooling_layer=False) + self.cls = RemBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
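+
+        Example (illustrative usage sketch; outputs are not verified):
+
+        ```python
+        >>> from transformers import AutoTokenizer, RemBertForMaskedLM
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/rembert")
+        >>> model = RemBertForMaskedLM.from_pretrained("google/rembert")
+
+        >>> inputs = tokenizer(f"The capital of France is {tokenizer.mask_token}.", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> # pick the highest-scoring token at the masked position
+        >>> mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+        >>> predicted_ids = logits[0, mask_positions].argmax(dim=-1)
+        >>> tokenizer.decode(predicted_ids)  # doctest: +SKIP
+        ```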
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rembert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def can_generate(cls) -> bool: + """ + Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. 
+ """ + return False + + +@add_start_docstrings( + """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING +) +class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `RemBertForCausalLM` as a standalone, add `is_decoder=True.`") + + self.rembert = RemBertModel(config, add_pooling_layer=False) + self.cls = RemBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RemBertForCausalLM, RemBertConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google/rembert") + >>> config = RemBertConfig.from_pretrained("google/rembert") + >>> config.is_decoder = True + >>> model = RemBertForCausalLM.from_pretrained("google/rembert", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rembert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + """ + RemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
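+
+    When `labels` are provided, the loss follows `config.problem_type`: mean-squared-error regression when
+    `num_labels == 1`, single-label cross-entropy for integer labels with `num_labels > 1`, and BCE-with-logits for
+    multi-label targets; if `problem_type` is unset, it is inferred from `num_labels` and the label dtype, as shown
+    in the `forward` body below.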
+ """, + REMBERT_START_DOCSTRING, +) +class RemBertForSequenceClassification(RemBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.rembert = RemBertModel(config) + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rembert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + REMBERT_START_DOCSTRING, +) +class RemBertForMultipleChoice(RemBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.rembert = RemBertModel(config) + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.rembert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + REMBERT_START_DOCSTRING, +) +class RemBertForTokenClassification(RemBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.rembert = RemBertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rembert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + REMBERT_START_DOCSTRING, +) +class RemBertForQuestionAnswering(RemBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + + self.rembert = RemBertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rembert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "RemBertForCausalLM", + "RemBertForMaskedLM", + "RemBertForMultipleChoice", + "RemBertForQuestionAnswering", + "RemBertForSequenceClassification", + "RemBertForTokenClassification", + "RemBertLayer", + "RemBertModel", + "RemBertPreTrainedModel", + "load_tf_weights_in_rembert", +] diff --git a/docs/transformers/src/transformers/models/rembert/modeling_tf_rembert.py b/docs/transformers/src/transformers/models/rembert/modeling_tf_rembert.py new file mode 100644 index 0000000000000000000000000000000000000000..4a21ee48d39fd1ca5fc0aef97b5cbb8d25ad479b --- /dev/null +++ b/docs/transformers/src/transformers/models/rembert/modeling_tf_rembert.py @@ -0,0 +1,1721 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 RemBERT model.""" + +from __future__ import annotations + +import math +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_rembert import RemBertConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RemBertConfig" + + +class TFRemBertEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.input_embedding_size = config.input_embedding_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.input_embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.input_embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.input_embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.input_embedding_size]) + + def call( + self, + input_ids: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + token_type_ids: Optional[tf.Tensor] = None, + inputs_embeds: Optional[tf.Tensor] = None, + past_key_values_length=0, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. 
+ """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->RemBert +class TFRemBertSelfAttention(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFRemBertModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
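+        # dropout here zeroes individual attention weights at random (training only), so a query may temporarily
+        # ignore some key positions; the surviving weights are rescaled to keep the expected sum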
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->RemBert +class TFRemBertSelfOutput(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->RemBert +class TFRemBertAttention(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFRemBertSelfAttention(config, name="self") + self.dense_output = TFRemBertSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], 
input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RemBert +class TFRemBertIntermediate(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RemBert +class TFRemBertOutput(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->RemBert +class TFRemBertLayer(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFRemBertAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFRemBertAttention(config, name="crossattention") + self.intermediate = TFRemBertIntermediate(config, name="intermediate") + self.bert_output = TFRemBertOutput(config, name="output") + + def 
call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: Tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +class TFRemBertEncoder(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.embedding_hidden_mapping_in = keras.layers.Dense( + 
units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="embedding_hidden_mapping_in", + ) + self.layer = [TFRemBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_values: Tuple[Tuple[tf.Tensor]], + use_cache: bool, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedding_hidden_mapping_in", None) is not None: + with tf.name_scope(self.embedding_hidden_mapping_in.name): + self.embedding_hidden_mapping_in.build([None, None, self.config.input_embedding_size]) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->RemBert +class TFRemBertPooler(keras.layers.Layer): + def __init__(self, config: RemBertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
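+        # i.e. the hidden state of the leading [CLS] token, passed through a tanh-activated dense layer as in BERT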
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFRemBertLMPredictionHead(keras.layers.Layer): + def __init__(self, config: RemBertConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.initializer_range = config.initializer_range + self.output_embedding_size = config.output_embedding_size + self.dense = keras.layers.Dense( + config.output_embedding_size, kernel_initializer=get_initializer(self.initializer_range), name="dense" + ) + if isinstance(config.hidden_act, str): + self.activation = get_tf_activation(config.hidden_act) + else: + self.activation = config.hidden_act + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + + def build(self, input_shape=None): + self.decoder = self.add_weight( + name="decoder/weight", + shape=[self.config.vocab_size, self.output_embedding_size], + initializer=get_initializer(self.initializer_range), + ) + self.decoder_bias = self.add_weight( + shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias" + ) + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, self.config.output_embedding_size]) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self + + def set_output_embeddings(self, value): + self.decoder = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self) -> Dict[str, tf.Variable]: + return {"decoder_bias": self.decoder_bias} + + def set_bias(self, value: tf.Variable): + self.decoder_bias = value["decoder_bias"] + self.config.vocab_size = shape_list(value["decoder_bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.activation(hidden_states) + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.output_embedding_size]) + hidden_states = self.LayerNorm(hidden_states) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias) + return hidden_states + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->RemBert +class TFRemBertMLMHead(keras.layers.Layer): + def __init__(self, config: RemBertConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.predictions = TFRemBertLMPredictionHead(config, input_embeddings, name="predictions") + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + prediction_scores = self.predictions(hidden_states=sequence_output) + + return prediction_scores + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + 
self.predictions.build(None) + + +@keras_serializable +class TFRemBertMainLayer(keras.layers.Layer): + config_class = RemBertConfig + + def __init__(self, config: RemBertConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.embeddings = TFRemBertEmbeddings(config, name="embeddings") + self.encoder = TFRemBertEncoder(config, name="encoder") + self.pooler = TFRemBertPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
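+        # For example, a (batch_size, seq_length) padding mask [[1, 1, 1, 0]] becomes a (batch_size, 1, 1, seq_length)
+        # tensor [[[[0.0, 0.0, 0.0, -10000.0]]]] further below, which is added to the raw attention scores so that
+        # masked positions end up with near-zero probability after the softmax.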
+        attention_mask_shape = shape_list(attention_mask)
+
+        mask_seq_length = seq_length + past_key_values_length
+        # Copied from `modeling_tf_t5.py`
+        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+        if self.is_decoder:
+            seq_ids = tf.range(mask_seq_length)
+            causal_mask = tf.less_equal(
+                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+                seq_ids[None, :, None],
+            )
+            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+            extended_attention_mask = causal_mask * attention_mask[:, None, :]
+            attention_mask_shape = shape_list(extended_attention_mask)
+            extended_attention_mask = tf.reshape(
+                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+            )
+            if past_key_values[0] is not None:
+                # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
+                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
+        else:
+            extended_attention_mask = tf.reshape(
+                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+        if self.is_decoder and encoder_attention_mask is not None:
+            # If a 2D or 3D attention mask is provided for the cross-attention,
+            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+            if num_dims_encoder_attention_mask == 3:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+            if num_dims_encoder_attention_mask == 2:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (
+                sequence_output,
+                pooled_output,
+            ) + encoder_outputs[1:]
+
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "pooler", None) is not None:
+            with tf.name_scope(self.pooler.name):
+                self.pooler.build(None)
+
+
+class TFRemBertPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = RemBertConfig
+    base_model_prefix = "rembert"
+
+
+REMBERT_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    </Tip>
+
+    Args:
+        config ([`RemBertConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+REMBERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+            [`PreTrainedTokenizer.encode`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+    "The bare RemBERT Model transformer outputting raw hidden-states without any specific head on top.",
+    REMBERT_START_DOCSTRING,
+)
+class TFRemBertModel(TFRemBertPreTrainedModel):
+    def __init__(self, config: RemBertConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.rembert = TFRemBertMainLayer(config, name="rembert")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="google/rembert",
+        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
+        r"""
+        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.rembert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + + +@add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING) +class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLoss): + def __init__(self, config: RemBertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if config.is_decoder: + logger.warning( + "If you want to use `TFRemBertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.rembert = TFRemBertMainLayer(config, name="rembert", add_pooling_layer=False) + self.mlm = TFRemBertMLMHead(config, input_embeddings=self.rembert.embeddings, name="mlm___cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + @unpack_inputs + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.rembert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +@add_start_docstrings( + """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING +) +class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config: RemBertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning("If you want to use `TFRemBertForCausalLM` as a standalone, add `is_decoder=True.`") + + self.rembert = TFRemBertMainLayer(config, name="rembert", add_pooling_layer=False) + self.mlm = TFRemBertMLMHead(config, input_embeddings=self.rembert.embeddings, name="mlm___cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + @unpack_inputs + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + outputs = self.rembert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.mlm(sequence_output=sequence_output, training=training) + loss = None + + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +@add_start_docstrings( + """ + RemBERT Model transformer with a sequence classification/regression head on top e.g., for GLUE tasks. 
+ """, + REMBERT_START_DOCSTRING, +) +class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: RemBertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.rembert = TFRemBertMainLayer(config, name="rembert") + self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.rembert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + REMBERT_START_DOCSTRING, +) +class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss): + def __init__(self, config: RemBertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.rembert = TFRemBertMainLayer(config, name="rembert") + self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) + self.classifier = keras.layers.Dense( + units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None + flat_attention_mask = ( + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None + ) + flat_token_type_ids = ( + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None + ) + flat_position_ids = ( + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None + ) + flat_inputs_embeds = ( + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.rembert( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def 
build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + REMBERT_START_DOCSTRING, +) +class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config: RemBertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.rembert = TFRemBertMainLayer(config, name="rembert", add_pooling_layer=False) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.rembert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(inputs=sequence_output, training=training) + logits = self.classifier(inputs=sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + REMBERT_START_DOCSTRING, +) +class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config: RemBertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.rembert = TFRemBertMainLayer(config, add_pooling_layer=False, name="rembert") + self.qa_outputs = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/rembert", + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]: + r""" + start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. 
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + outputs = self.rembert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.qa_outputs(inputs=sequence_output) + start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) + start_logits = tf.squeeze(input=start_logits, axis=-1) + end_logits = tf.squeeze(input=end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "rembert", None) is not None: + with tf.name_scope(self.rembert.name): + self.rembert.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFRemBertForCausalLM", + "TFRemBertForMaskedLM", + "TFRemBertForMultipleChoice", + "TFRemBertForQuestionAnswering", + "TFRemBertForSequenceClassification", + "TFRemBertForTokenClassification", + "TFRemBertLayer", + "TFRemBertModel", + "TFRemBertPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/rembert/tokenization_rembert.py b/docs/transformers/src/transformers/models/rembert/tokenization_rembert.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2053425d727d3ef72a2dc219c0cd81392f96cc --- /dev/null +++ b/docs/transformers/src/transformers/models/rembert/tokenization_rembert.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
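The slow tokenizer defined in the file below delegates every string/piece/ID conversion to a `SentencePieceProcessor`. As a quick orientation, here is a minimal round-trip sketch using only the sentencepiece calls that appear in this file; the model path is a placeholder for a real RemBERT `sentencepiece.model`:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("sentencepiece.model")  # placeholder path to a trained model file

pieces = sp.EncodeAsPieces("Hello world")        # what _tokenize() returns
ids = [sp.PieceToId(piece) for piece in pieces]  # _convert_token_to_id()
back = [sp.IdToPiece(i) for i in ids]            # _convert_id_to_token()
text = sp.decode_pieces(back)                    # convert_tokens_to_string()
# text == "Hello world" for a typical sentencepiece model
```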
+"""Tokenization classes for RemBERT.""" + +import os +from shutil import copyfile +from typing import List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.model"} + + +@requires(backends=("sentencepiece",)) +class RemBertTokenizer(PreTrainedTokenizer): + """ + Construct a RemBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + bos_token (`str`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=False, + remove_space=True, + keep_accents=True, + bos_token="[CLS]", + eos_token="[SEP]", + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs, + ): + # Mask token behave like a normal word, i.e. 
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.keep_accents = keep_accents + self.vocab_file = vocab_file + + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + keep_accents=keep_accents, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.sp_model) + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text, sample=False): + """Tokenize a string.""" + pieces = self.sp_model.EncodeAsPieces(text) + return pieces + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index) + + def convert_tokens_to_string(self, tokens): + out_string = self.sp_model.decode_pieces(tokens) + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A REMBERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return cls + token_ids_0 + sep + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." 
+ ) + return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0] + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RemBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["RemBertTokenizer"] diff --git a/docs/transformers/src/transformers/models/rembert/tokenization_rembert_fast.py b/docs/transformers/src/transformers/models/rembert/tokenization_rembert_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..7eed4f80a7878828f10713bc69a389eb239d5d5b --- /dev/null +++ b/docs/transformers/src/transformers/models/rembert/tokenization_rembert_fast.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
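Both the slow tokenizer above and the fast tokenizer that follows assemble pair inputs as `[CLS] A [SEP] B [SEP]`. To make the three parallel outputs described in the docstrings concrete, here is a small hand-computed sketch; the token IDs are made up for illustration (RemBERT's real IDs come from its sentencepiece vocab):

```python
cls_id, sep_id = 101, 102  # placeholder special-token IDs
seq_a = [7, 8, 9]          # first sequence, already converted to IDs
seq_b = [21, 22]           # second sequence

# build_inputs_with_special_tokens: [CLS] A [SEP] B [SEP]
input_ids = [cls_id] + seq_a + [sep_id] + seq_b + [sep_id]

# create_token_type_ids_from_sequences: 0s cover "[CLS] A [SEP]", 1s cover "B [SEP]"
token_type_ids = [0] * (len(seq_a) + 2) + [1] * (len(seq_b) + 1)

# get_special_tokens_mask: 1 marks a special token, 0 marks a sequence token
special_tokens_mask = [1] + [0] * len(seq_a) + [1] + [0] * len(seq_b) + [1]

assert len(input_ids) == len(token_type_ids) == len(special_tokens_mask) == 8
```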
+"""Tokenization classes for RemBERT model.""" + +import os +from shutil import copyfile +from typing import List, Optional, Tuple + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_rembert import RemBertTokenizer +else: + RemBertTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.model", "tokenizer_file": "tokenizer.json"} + + +SPIECE_UNDERLINE = "▁" + + +class RemBertTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" RemBert tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This + tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + remove_space (`bool`, *optional*, defaults to `True`): + Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). + keep_accents (`bool`, *optional*, defaults to `False`): + Whether or not to keep accents when tokenizing. + bos_token (`str`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. .. note:: When building a sequence using special tokens, this is not the token + that is used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = RemBertTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        do_lower_case=True,
+        remove_space=True,
+        keep_accents=False,
+        bos_token="[CLS]",
+        eos_token="[SEP]",
+        unk_token="<unk>",
+        sep_token="[SEP]",
+        pad_token="<pad>",
+        cls_token="[CLS]",
+        mask_token="[MASK]",
+        **kwargs,
+    ):
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            do_lower_case=do_lower_case,
+            remove_space=remove_space,
+            keep_accents=keep_accents,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            cls_token=cls_token,
+            mask_token=mask_token,
+            **kwargs,
+        )
+
+        self.do_lower_case = do_lower_case
+        self.remove_space = remove_space
+        self.keep_accents = keep_accents
+        self.vocab_file = vocab_file
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+        adding special tokens. A RemBERT sequence has the following format:
+
+        - single sequence: `[CLS] X [SEP]`
+        - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*, defaults to `None`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        if token_ids_1 is None:
+            return cls + token_ids_0 + sep
+        return cls + token_ids_0 + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids.
+            token_ids_1 (`List[int]`, *optional*, defaults to `None`):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Set to True if the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+ ) + return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0] + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. A RemBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*, defaults to `None`): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["RemBertTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/resnet/__init__.py b/docs/transformers/src/transformers/models/resnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..625e93a25543701db300db039d736589f2148e6d --- /dev/null +++ b/docs/transformers/src/transformers/models/resnet/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
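The `resnet/__init__.py` below registers its submodules with `_LazyModule`, so the actual framework imports happen only when a symbol is first accessed. A rough sketch of the underlying idea using PEP 562's module-level `__getattr__`; the export table here is hypothetical and far simpler than what `define_import_structure` produces:

```python
# hypothetical lazy __init__.py for a package with one submodule
import importlib

_EXPORTS = {"ResNetConfig": ".configuration_resnet"}  # symbol -> submodule


def __getattr__(name):
    # The submodule is imported only on first attribute access,
    # keeping a plain `import package` cheap.
    if name in _EXPORTS:
        module = importlib.import_module(_EXPORTS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```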
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_resnet import * + from .modeling_flax_resnet import * + from .modeling_resnet import * + from .modeling_tf_resnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/resnet/configuration_resnet.py b/docs/transformers/src/transformers/models/resnet/configuration_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..42bc19a8bca5b621cdcad893c71630ed15e9f316 --- /dev/null +++ b/docs/transformers/src/transformers/models/resnet/configuration_resnet.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ResNet model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class ResNetConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an + ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the ResNet + [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"bottleneck"`): + The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or + `"bottleneck"` (used for larger models like resnet-50 and above). + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + downsample_in_first_stage (`bool`, *optional*, defaults to `False`): + If `True`, the first stage will downsample the inputs using a `stride` of 2. 
+ downsample_in_bottleneck (`bool`, *optional*, defaults to `False`): + If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + ```python + >>> from transformers import ResNetConfig, ResNetModel + + >>> # Initializing a ResNet resnet-50 style configuration + >>> configuration = ResNetConfig() + + >>> # Initializing a model (with random weights) from the resnet-50 style configuration + >>> model = ResNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "resnet" + layer_types = ["basic", "bottleneck"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.downsample_in_first_stage = downsample_in_first_stage + self.downsample_in_bottleneck = downsample_in_bottleneck + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +class ResNetOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-3 + + +__all__ = ["ResNetConfig", "ResNetOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/resnet/convert_resnet_to_pytorch.py b/docs/transformers/src/transformers/models/resnet/convert_resnet_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..4909f1dc670358c8da12e09e5ce53a5770a7809f --- /dev/null +++ b/docs/transformers/src/transformers/models/resnet/convert_resnet_to_pytorch.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ResNet checkpoints from timm.""" + +import argparse +import json +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import List, Optional + +import timm +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from torch import Tensor + +from transformers import AutoImageProcessor, ResNetConfig, ResNetForImageClassification +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger() + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + self.traced.append(m) + + def __call__(self, x: Tensor): + for m in self.module.modules(): + self.handles.append(m.register_forward_hook(self._forward_hook)) + self.module(x) + [x.remove() for x in self.handles] + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced)) + + +@dataclass +class ModuleTransfer: + src: nn.Module + dest: nn.Module + verbose: int = 0 + src_skip: List = field(default_factory=list) + dest_skip: List = field(default_factory=list) + + def __call__(self, x: Tensor): + """ + Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the + hood we tracked all the operations in both modules. + """ + dest_traced = Tracker(self.dest)(x).parametrized + src_traced = Tracker(self.src)(x).parametrized + + src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced)) + dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced)) + + if len(dest_traced) != len(src_traced): + raise Exception( + f"Numbers of operations are different. Source module has {len(src_traced)} operations while" + f" destination module has {len(dest_traced)}." + ) + + for dest_m, src_m in zip(dest_traced, src_traced): + dest_m.load_state_dict(src_m.state_dict()) + if self.verbose == 1: + print(f"Transfered from={src_m} to={dest_m}") + + +def convert_weight_and_push(name: str, config: ResNetConfig, save_directory: Path, push_to_hub: bool = True): + print(f"Converting {name}...") + with torch.no_grad(): + from_model = timm.create_model(name, pretrained=True).eval() + our_model = ResNetForImageClassification(config).eval() + module_transfer = ModuleTransfer(src=from_model, dest=our_model) + x = torch.randn((1, 3, 224, 224)) + module_transfer(x) + + assert torch.allclose(from_model(x), our_model(x).logits), "The model logits don't match the original one." 
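+    # Note: the order-based ModuleTransfer above assumes timm's ResNet and our
+    # ResNetForImageClassification execute their parametrized leaf modules
+    # (Conv2d, BatchNorm2d, Linear) in the same order during a forward pass;
+    # the allclose check is what catches a silent ordering mismatch before
+    # anything is pushed to the Hub.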
+
+    checkpoint_name = f"resnet{'-'.join(name.split('resnet'))}"
+    print(checkpoint_name)
+
+    if push_to_hub:
+        our_model.push_to_hub(
+            repo_path_or_name=save_directory / checkpoint_name,
+            commit_message="Add model",
+            use_temp_dir=True,
+        )
+
+        # we can reuse the ConvNeXt image processor for ResNet checkpoints
+        image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
+        image_processor.push_to_hub(
+            repo_path_or_name=save_directory / checkpoint_name,
+            commit_message="Add image processor",
+            use_temp_dir=True,
+        )
+
+        print(f"Pushed {checkpoint_name}")
+
+
+def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True):
+    filename = "imagenet-1k-id2label.json"
+    num_labels = 1000
+    expected_shape = (1, num_labels)
+
+    repo_id = "huggingface/label-files"
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+    id2label = {int(k): v for k, v in id2label.items()}
+    label2id = {v: k for k, v in id2label.items()}
+
+    ImageNetPreTrainedConfig = partial(ResNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
+
+    names_to_config = {
+        "resnet18": ImageNetPreTrainedConfig(
+            depths=[2, 2, 2, 2], hidden_sizes=[64, 128, 256, 512], layer_type="basic"
+        ),
+        "resnet26": ImageNetPreTrainedConfig(
+            depths=[2, 2, 2, 2], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
+        ),
+        "resnet34": ImageNetPreTrainedConfig(
+            depths=[3, 4, 6, 3], hidden_sizes=[64, 128, 256, 512], layer_type="basic"
+        ),
+        "resnet50": ImageNetPreTrainedConfig(
+            depths=[3, 4, 6, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
+        ),
+        "resnet101": ImageNetPreTrainedConfig(
+            depths=[3, 4, 23, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
+        ),
+        "resnet152": ImageNetPreTrainedConfig(
+            depths=[3, 8, 36, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
+        ),
+    }
+
+    if model_name:
+        config = names_to_config[model_name]
+        convert_weight_and_push(model_name, config, save_directory, push_to_hub)
+    else:
+        for model_name, config in names_to_config.items():
+            convert_weight_and_push(model_name, config, save_directory, push_to_hub)
+    return config, expected_shape
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--model_name",
+        default=None,
+        type=str,
+        help=(
+            "The name of the model you wish to convert, it must be one of the supported resnet* architectures,"
+            " currently: resnet18,26,34,50,101,152. If `None`, all of them will be converted."
+ ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and image processor to the hub.", + ) + + args = parser.parse_args() + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/resnet/modeling_flax_resnet.py b/docs/transformers/src/transformers/models/resnet/modeling_flax_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fc2cc1be55462a9e2d417479ebcb457c98b516 --- /dev/null +++ b/docs/transformers/src/transformers/models/resnet/modeling_flax_resnet.py @@ -0,0 +1,704 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithNoAttention, + FlaxBaseModelOutputWithPoolingAndNoAttention, + FlaxImageClassifierOutputWithNoAttention, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_resnet import ResNetConfig + + +RESNET_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`AutoImageProcessor.__call__`] for details. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Identity(nn.Module): + """Identity function.""" + + @nn.compact + def __call__(self, x, **kwargs): + return x + + +class FlaxResNetConvLayer(nn.Module): + out_channels: int + kernel_size: int = 3 + stride: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(self.kernel_size, self.kernel_size), + strides=self.stride, + padding=self.kernel_size // 2, + dtype=self.dtype, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype), + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embedder = FlaxResNetConvLayer( + self.config.embedding_size, + kernel_size=7, + stride=2, + activation=self.config.hidden_act, + dtype=self.dtype, + ) + + self.max_pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) + + def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + num_channels = pixel_values.shape[-1] + if num_channels != self.config.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values, deterministic=deterministic) + embedding = self.max_pool(embedding) + return embedding + + +class FlaxResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. 
+ """ + + out_channels: int + stride: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(1, 1), + strides=self.stride, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + return hidden_state + + +class FlaxResNetBasicLayerCollection(nn.Module): + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layer = [ + FlaxResNetConvLayer(self.out_channels, stride=self.stride, dtype=self.dtype), + FlaxResNetConvLayer(self.out_channels, activation=None, dtype=self.dtype), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. + """ + + in_channels: int + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) + if should_apply_shortcut + else None + ) + self.layer = FlaxResNetBasicLayerCollection( + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.activation] + + def __call__(self, hidden_state, deterministic: bool = True): + residual = hidden_state + hidden_state = self.layer(hidden_state, deterministic=deterministic) + + if self.shortcut is not None: + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetBottleNeckLayerCollection(nn.Module): + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + reduction: int = 4 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + reduces_channels = self.out_channels // self.reduction + + self.layer = [ + FlaxResNetConvLayer(reduces_channels, kernel_size=1, dtype=self.dtype, name="0"), + FlaxResNetConvLayer(reduces_channels, stride=self.stride, dtype=self.dtype, name="1"), + FlaxResNetConvLayer(self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetBottleNeckLayer(nn.Module): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the + input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution + remaps the reduced features to `out_channels`. 
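+
+    For instance, with `out_channels=256` and the default `reduction=4`, the channel progression of the three
+    convolutions is `1x1: in_channels -> 64`, `3x3: 64 -> 64`, `1x1: 64 -> 256`.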
+ """ + + in_channels: int + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + reduction: int = 4 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) + if should_apply_shortcut + else None + ) + + self.layer = FlaxResNetBottleNeckLayerCollection( + self.out_channels, + stride=self.stride, + activation=self.activation, + reduction=self.reduction, + dtype=self.dtype, + ) + + self.activation_func = ACT2FN[self.activation] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + + if self.shortcut is not None: + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state = self.layer(hidden_state, deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetStageLayersCollection(nn.Module): + """ + A ResNet stage composed by stacked layers. + """ + + config: ResNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + layer = FlaxResNetBottleNeckLayer if self.config.layer_type == "bottleneck" else FlaxResNetBasicLayer + + layers = [ + # downsampling is done in the first layer with stride of 2 + layer( + self.in_channels, + self.out_channels, + stride=self.stride, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + ] + + for i in range(self.depth - 1): + layers.append( + layer( + self.out_channels, + self.out_channels, + activation=self.config.hidden_act, + dtype=self.dtype, + name=str(i + 1), + ) + ) + + self.layers = layers + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = x + for layer in self.layers: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetStage(nn.Module): + """ + A ResNet stage composed by stacked layers. 
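+
+    The first layer downsamples the input with the given `stride` (2 by default) and projects it to `out_channels`;
+    the remaining `depth - 1` layers keep the spatial resolution and channel count.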
+ """ + + config: ResNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = FlaxResNetStageLayersCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + depth=self.depth, + dtype=self.dtype, + ) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + return self.layers(x, deterministic=deterministic) + + +class FlaxResNetStageCollection(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) + stages = [ + FlaxResNetStage( + self.config, + self.config.embedding_size, + self.config.hidden_sizes[0], + stride=2 if self.config.downsample_in_first_stage else 1, + depth=self.config.depths[0], + dtype=self.dtype, + name="0", + ) + ] + + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): + stages.append( + FlaxResNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) + ) + + self.stages = stages + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + hidden_state = stage_module(hidden_state, deterministic=deterministic) + + return hidden_state, hidden_states + + +class FlaxResNetEncoder(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.stages = FlaxResNetStageCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_state, hidden_states = self.stages( + hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic + ) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return FlaxBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class FlaxResNetPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
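+
+    Note that, because the model contains BatchNorm layers, its variables are split into a `params` and a
+    `batch_stats` collection, and `__call__` accepts NCHW `pixel_values` (matching the PyTorch model) which are
+    transposed to the NHWC layout used internally by the Flax modules.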
+ """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: ResNetConfig, + input_shape=(1, 224, 224, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + rngs = {"params": rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False) + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + params: dict = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + # Handle any PRNG if needed + rngs = {} + + return self.module.apply( + { + "params": params["params"] if params is not None else self.params["params"], + "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], + }, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True + ) + + +class FlaxResNetModule(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedder = FlaxResNetEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxResNetEncoder(self.config, dtype=self.dtype) + + # Adaptive average pooling used in resnet + self.pooler = partial( + nn.avg_pool, + padding=((0, 0), (0, 0)), + ) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler( + last_hidden_state, + window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + ).transpose(0, 3, 1, 2) + + last_hidden_state = 
last_hidden_state.transpose(0, 3, 1, 2) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class FlaxResNetModel(FlaxResNetPreTrainedModel): + module_class = FlaxResNetModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxResNetModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50") + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxResNetModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxResNetModel, output_type=FlaxBaseModelOutputWithPoolingAndNoAttention, config_class=ResNetConfig +) + + +class FlaxResNetClassifierCollection(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.classifier(x) + + +class FlaxResNetForImageClassificationModule(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.resnet = FlaxResNetModule(config=self.config, dtype=self.dtype) + + if self.config.num_labels > 0: + self.classifier = FlaxResNetClassifierCollection(self.config, dtype=self.dtype) + else: + self.classifier = Identity() + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet( + pixel_values, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output[:, :, 0, 0]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
+ """, + RESNET_START_DOCSTRING, +) +class FlaxResNetForImageClassification(FlaxResNetPreTrainedModel): + module_class = FlaxResNetForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxResNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxResNetForImageClassification, output_type=FlaxImageClassifierOutputWithNoAttention, config_class=ResNetConfig +) + + +__all__ = ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/resnet/modeling_resnet.py b/docs/transformers/src/transformers/models/resnet/modeling_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b757414892d05597b1d2320b08adf87c1d69ca --- /dev/null +++ b/docs/transformers/src/transformers/models/resnet/modeling_resnet.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch ResNet model.""" + +import math +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_resnet import ResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ResNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat" + + +class ResNetConvLayer(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu" + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else nn.Identity() + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: ResNetConfig): + super().__init__() + self.embedder = ResNetConvLayer( + config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act + ) + self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.num_channels = config.num_channels + + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values) + embedding = self.pooler(embedding) + return embedding + + +class ResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.normalization = nn.BatchNorm2d(out_channels) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class ResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. 
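+
+    As an illustrative sketch:
+
+    ```python
+    >>> import torch
+
+    >>> layer = ResNetBasicLayer(in_channels=64, out_channels=128, stride=2)
+    >>> layer(torch.randn(1, 64, 56, 56)).shape
+    torch.Size([1, 128, 28, 28])
+    ```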
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
+        super().__init__()
+        should_apply_shortcut = in_channels != out_channels or stride != 1
+        self.shortcut = (
+            ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
+        )
+        self.layer = nn.Sequential(
+            ResNetConvLayer(in_channels, out_channels, stride=stride),
+            ResNetConvLayer(out_channels, out_channels, activation=None),
+        )
+        self.activation = ACT2FN[activation]
+
+    def forward(self, hidden_state):
+        residual = hidden_state
+        hidden_state = self.layer(hidden_state)
+        residual = self.shortcut(residual)
+        hidden_state += residual
+        hidden_state = self.activation(hidden_state)
+        return hidden_state
+
+
+class ResNetBottleNeckLayer(nn.Module):
+    """
+    A classic ResNet's bottleneck layer composed of a `1x1`, a `3x3` and a `1x1` convolution.
+
+    The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the `3x3` convolution
+    faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If `downsample_in_bottleneck`
+    is `True`, downsampling happens in the first `1x1` convolution instead of the `3x3` convolution.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int = 1,
+        activation: str = "relu",
+        reduction: int = 4,
+        downsample_in_bottleneck: bool = False,
+    ):
+        super().__init__()
+        should_apply_shortcut = in_channels != out_channels or stride != 1
+        reduces_channels = out_channels // reduction
+        self.shortcut = (
+            ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
+        )
+        self.layer = nn.Sequential(
+            ResNetConvLayer(
+                in_channels, reduces_channels, kernel_size=1, stride=stride if downsample_in_bottleneck else 1
+            ),
+            ResNetConvLayer(reduces_channels, reduces_channels, stride=stride if not downsample_in_bottleneck else 1),
+            ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
+        )
+        self.activation = ACT2FN[activation]
+
+    def forward(self, hidden_state):
+        residual = hidden_state
+        hidden_state = self.layer(hidden_state)
+        residual = self.shortcut(residual)
+        hidden_state += residual
+        hidden_state = self.activation(hidden_state)
+        return hidden_state
+
+
+class ResNetStage(nn.Module):
+    """
+    A ResNet stage composed of stacked layers.
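+
+    The first layer downsamples the input with the given `stride`; the remaining `depth - 1` layers keep the
+    resolution. As an illustrative sketch with the default bottleneck configuration:
+
+    ```python
+    >>> import torch
+    >>> from transformers import ResNetConfig
+
+    >>> stage = ResNetStage(ResNetConfig(), in_channels=256, out_channels=512, stride=2, depth=4)
+    >>> stage(torch.randn(1, 256, 28, 28)).shape
+    torch.Size([1, 512, 14, 14])
+    ```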
+ """ + + def __init__( + self, + config: ResNetConfig, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + ): + super().__init__() + + layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer + + if config.layer_type == "bottleneck": + first_layer = layer( + in_channels, + out_channels, + stride=stride, + activation=config.hidden_act, + downsample_in_bottleneck=config.downsample_in_bottleneck, + ) + else: + first_layer = layer(in_channels, out_channels, stride=stride, activation=config.hidden_act) + self.layers = nn.Sequential( + first_layer, *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)] + ) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class ResNetEncoder(nn.Module): + def __init__(self, config: ResNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages.append( + ResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append(ResNetStage(config, in_channels, out_channels, depth=depth)) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class ResNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +RESNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. 
+ + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ConvNextImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class ResNetModel(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embedder = ResNetEmbeddings(config) + self.encoder = ResNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
+ """, + RESNET_START_DOCSTRING, +) +class ResNetForImageClassification(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.resnet = ResNetModel(config) + # classification head + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(), + ) + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + ResNet backbone, to be used with frameworks like DETR and MaskFormer. 
+ """, + RESNET_START_DOCSTRING, +) +class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embedding_size] + config.hidden_sizes + self.embedder = ResNetEmbeddings(config) + self.encoder = ResNetEncoder(config) + + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output = self.embedder(pixel_values) + + outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = ["ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", "ResNetBackbone"] diff --git a/docs/transformers/src/transformers/models/resnet/modeling_tf_resnet.py b/docs/transformers/src/transformers/models/resnet/modeling_tf_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0f32c04f459c2762a1ec8bdd06e7f2744976cd7e --- /dev/null +++ b/docs/transformers/src/transformers/models/resnet/modeling_tf_resnet.py @@ -0,0 +1,596 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TensorFlow ResNet model.""" + +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFBaseModelOutputWithPoolingAndNoAttention, + TFImageClassifierOutputWithNoAttention, +) +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_resnet import ResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ResNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat" + + +class TFResNetConvLayer(keras.layers.Layer): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + activation: str = "relu", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.pad_value = kernel_size // 2 + self.conv = keras.layers.Conv2D( + out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution" + ) + # Use same default momentum and epsilon as PyTorch equivalent + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.activation = ACT2FN[activation] if activation is not None else keras.layers.Activation("linear") + self.in_channels = in_channels + self.out_channels = out_channels + + def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: + # Pad to match that done in the PyTorch Conv2D model + height_pad = width_pad = (self.pad_value, self.pad_value) + hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) + hidden_state = self.conv(hidden_state) + return hidden_state + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFResNetEmbeddings(keras.layers.Layer): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. 
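+
+    As an illustrative sketch (note that these inner layers take NHWC inputs; the NCHW -> NHWC transpose happens in
+    `TFResNetMainLayer`):
+
+    ```python
+    >>> import tensorflow as tf
+    >>> from transformers import ResNetConfig
+
+    >>> embeddings = TFResNetEmbeddings(ResNetConfig())
+    >>> embeddings(tf.zeros((1, 224, 224, 3))).shape
+    TensorShape([1, 56, 56, 64])
+    ```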
+ """ + + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.embedder = TFResNetConvLayer( + config.num_channels, + config.embedding_size, + kernel_size=7, + stride=2, + activation=config.hidden_act, + name="embedder", + ) + self.pooler = keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler") + self.num_channels = config.num_channels + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + _, _, _, num_channels = shape_list(pixel_values) + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = pixel_values + hidden_state = self.embedder(hidden_state) + hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]]) + hidden_state = self.pooler(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedder", None) is not None: + with tf.name_scope(self.embedder.name): + self.embedder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + + +class TFResNetShortCut(keras.layers.Layer): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs) -> None: + super().__init__(**kwargs) + self.convolution = keras.layers.Conv2D( + out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" + ) + # Use same default momentum and epsilon as PyTorch equivalent + self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") + self.in_channels = in_channels + self.out_channels = out_channels + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = x + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution", None) is not None: + with tf.name_scope(self.convolution.name): + self.convolution.build([None, None, None, self.in_channels]) + if getattr(self, "normalization", None) is not None: + with tf.name_scope(self.normalization.name): + self.normalization.build([None, None, None, self.out_channels]) + + +class TFResNetBasicLayer(keras.layers.Layer): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. 
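+
+    When `in_channels != out_channels` or `stride != 1`, a `TFResNetShortCut` projects the residual so that it can
+    be added to the output of the main branch; otherwise the shortcut is a no-op linear activation.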
+    """
+
+    def __init__(
+        self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        should_apply_shortcut = in_channels != out_channels or stride != 1
+        self.conv1 = TFResNetConvLayer(in_channels, out_channels, stride=stride, name="layer.0")
+        self.conv2 = TFResNetConvLayer(out_channels, out_channels, activation=None, name="layer.1")
+        self.shortcut = (
+            TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
+            if should_apply_shortcut
+            else keras.layers.Activation("linear", name="shortcut")
+        )
+        self.activation = ACT2FN[activation]
+
+    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
+        residual = hidden_state
+        hidden_state = self.conv1(hidden_state, training=training)
+        hidden_state = self.conv2(hidden_state, training=training)
+        residual = self.shortcut(residual, training=training)
+        hidden_state += residual
+        hidden_state = self.activation(hidden_state)
+        return hidden_state
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "conv1", None) is not None:
+            with tf.name_scope(self.conv1.name):
+                self.conv1.build(None)
+        if getattr(self, "conv2", None) is not None:
+            with tf.name_scope(self.conv2.name):
+                self.conv2.build(None)
+        if getattr(self, "shortcut", None) is not None:
+            with tf.name_scope(self.shortcut.name):
+                self.shortcut.build(None)
+
+
+class TFResNetBottleNeckLayer(keras.layers.Layer):
+    """
+    A classic ResNet's bottleneck layer composed of a `1x1`, a `3x3` and a `1x1` convolution.
+
+    The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the `3x3` convolution
+    faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
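+
+    With `out_channels=2048` and the default `reduction=4`, for instance, the channel progression is
+    `1x1: in_channels -> 512`, `3x3: 512 -> 512`, `1x1: 512 -> 2048`.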
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + activation: str = "relu", + reduction: int = 4, + **kwargs, + ) -> None: + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + reduces_channels = out_channels // reduction + self.conv0 = TFResNetConvLayer(in_channels, reduces_channels, kernel_size=1, name="layer.0") + self.conv1 = TFResNetConvLayer(reduces_channels, reduces_channels, stride=stride, name="layer.1") + self.conv2 = TFResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None, name="layer.2") + self.shortcut = ( + TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else keras.layers.Activation("linear", name="shortcut") + ) + self.activation = ACT2FN[activation] + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + residual = hidden_state + hidden_state = self.conv0(hidden_state, training=training) + hidden_state = self.conv1(hidden_state, training=training) + hidden_state = self.conv2(hidden_state, training=training) + residual = self.shortcut(residual, training=training) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv0", None) is not None: + with tf.name_scope(self.conv0.name): + self.conv0.build(None) + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build(None) + if getattr(self, "conv2", None) is not None: + with tf.name_scope(self.conv2.name): + self.conv2.build(None) + if getattr(self, "shortcut", None) is not None: + with tf.name_scope(self.shortcut.name): + self.shortcut.build(None) + + +class TFResNetStage(keras.layers.Layer): + """ + A ResNet stage composed of stacked layers. 
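+
+    The sub-layer names (`"layers.0"`, `"layers.1"`, ...) deliberately mirror the indices of the PyTorch
+    `nn.Sequential` stage so that checkpoints can be converted between the two frameworks.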
+ """ + + def __init__( + self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs + ) -> None: + super().__init__(**kwargs) + + layer = TFResNetBottleNeckLayer if config.layer_type == "bottleneck" else TFResNetBasicLayer + + layers = [layer(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")] + layers += [ + layer(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{i + 1}") + for i in range(depth - 1) + ] + self.stage_layers = layers + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + for layer in self.stage_layers: + hidden_state = layer(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "stage_layers", None) is not None: + for layer in self.stage_layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFResNetEncoder(keras.layers.Layer): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages = [ + TFResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + name="stages.0", + ) + ] + for i, (in_channels, out_channels, depth) in enumerate( + zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:]) + ): + self.stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}")) + + def call( + self, + hidden_state: tf.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> TFBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state, training=training) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "stages", None) is not None: + for layer in self.stages: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFResNetPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + + @property + def input_signature(self): + return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)} + + +RESNET_START_DOCSTRING = r""" + This model is a TensorFlow + [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a + regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ConvNextImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@keras_serializable +class TFResNetMainLayer(keras.layers.Layer): + config_class = ResNetConfig + + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + self.embedder = TFResNetEmbeddings(config, name="embedder") + self.encoder = TFResNetEncoder(config, name="encoder") + self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True) + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TF 2.0 image layers can't use NCHW format when running on CPU. + # We transpose to NHWC format and then transpose back after the full forward pass. + # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) + embedding_output = self.embedder(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + # Transpose all the outputs to the NCHW format + # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width) + last_hidden_state = tf.transpose(last_hidden_state, (0, 3, 1, 2)) + pooled_output = tf.transpose(pooled_output, (0, 3, 1, 2)) + hidden_states = () + for hidden_state in encoder_outputs[1:]: + hidden_states = hidden_states + tuple(tf.transpose(h, (0, 3, 1, 2)) for h in hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + hidden_states + + hidden_states = hidden_states if output_hidden_states else None + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedder", None) is not None: + with tf.name_scope(self.embedder.name): + self.embedder.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class TFResNetModel(TFResNetPreTrainedModel): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self.resnet = TFResNetMainLayer(config=config, name="resnet") + + 
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + resnet_outputs = self.resnet( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return resnet_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "resnet", None) is not None: + with tf.name_scope(self.resnet.name): + self.resnet.build(None) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + RESNET_START_DOCSTRING, +) +class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self.num_labels = config.num_labels + self.resnet = TFResNetMainLayer(config, name="resnet") + # classification head + self.classifier_layer = ( + keras.layers.Dense(config.num_labels, name="classifier.1") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="classifier.1") + ) + self.config = config + + def classifier(self, x: tf.Tensor) -> tf.Tensor: + x = keras.layers.Flatten()(x) + logits = self.classifier_layer(x) + return logits + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "resnet", None) is not None: + with tf.name_scope(self.resnet.name): + self.resnet.build(None) + if getattr(self, "classifier_layer", None) is not None: + with tf.name_scope(self.classifier_layer.name): + self.classifier_layer.build([None, None, self.config.hidden_sizes[-1]]) + + +__all__ = ["TFResNetForImageClassification", "TFResNetModel", "TFResNetPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/roberta/__init__.py b/docs/transformers/src/transformers/models/roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9418d33d354ab6192d749d06ee32e39f5de670 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_roberta import * + from .modeling_flax_roberta import * + from .modeling_roberta import * + from .modeling_tf_roberta import * + from .tokenization_roberta import * + from .tokenization_roberta_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/roberta/configuration_roberta.py b/docs/transformers/src/transformers/models/roberta/configuration_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..35ff80115f1d021017300b8058de853ff828e5e9 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/configuration_roberta.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RoBERTa configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class RobertaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is + used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa + [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import RobertaConfig, RobertaModel + + >>> # Initializing a RoBERTa configuration + >>> configuration = RobertaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "roberta" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class RobertaOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["RobertaConfig", "RobertaOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e6bf94d2eb7241a19ccd197ca0cbd95be61b2a --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. 
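The dynamic axes declared by `RobertaOnnxConfig.inputs` above can be inspected directly; a minimal sketch, assuming a stock `transformers` install and the `"default"` export task:

```python
# Minimal sketch: inspect the ONNX dynamic axes declared above.
from transformers import RobertaConfig
from transformers.models.roberta.configuration_roberta import RobertaOnnxConfig

onnx_config = RobertaOnnxConfig(RobertaConfig(), task="default")
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])
```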
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RoBERTa checkpoint.""" + +import argparse +import pathlib + +import fairseq +import torch +from fairseq.models.roberta import RobertaModel as FairseqRobertaModel +from fairseq.modules import TransformerSentenceEncoderLayer +from packaging import version + +from transformers import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification +from transformers.models.bert.modeling_bert import ( + BertIntermediate, + BertLayer, + BertOutput, + BertSelfAttention, + BertSelfOutput, +) +from transformers.utils import logging + + +if version.parse(fairseq.__version__) < version.parse("0.9.0"): + raise Exception("requires fairseq >= 0.9.0") + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +SAMPLE_TEXT = "Hello world! cécé herlolip" + + +def convert_roberta_checkpoint_to_pytorch( + roberta_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool +): + """ + Copy/paste/tweak roberta's weights to our BERT structure. + """ + roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path) + roberta.eval() # disable dropout + roberta_sent_encoder = roberta.model.encoder.sentence_encoder + config = RobertaConfig( + vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings, + hidden_size=roberta.args.encoder_embed_dim, + num_hidden_layers=roberta.args.encoder_layers, + num_attention_heads=roberta.args.encoder_attention_heads, + intermediate_size=roberta.args.encoder_ffn_embed_dim, + max_position_embeddings=514, + type_vocab_size=1, + layer_norm_eps=1e-5, # PyTorch default used in fairseq + ) + if classification_head: + config.num_labels = roberta.model.classification_heads["mnli"].out_proj.weight.shape[0] + print("Our BERT config:", config) + + model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config) + model.eval() + + # Now let's copy all the weights. + # Embeddings + model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight + model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight + model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like( + model.roberta.embeddings.token_type_embeddings.weight + ) # just zero them out b/c RoBERTa doesn't use them. 
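Zeroing the token-type table is safe because RoBERTa is trained with a single segment id (`type_vocab_size=1` in the config built above), so the segment term contributes nothing to the embedding sum. A toy check of that property, with arbitrary shapes:

```python
import torch

# With an all-zero token-type table, the segment term vanishes from the
# embedding sum, which is why the conversion can simply zero it out.
token_type = torch.nn.Embedding(1, 8)
torch.nn.init.zeros_(token_type.weight)
token_type_ids = torch.zeros((2, 5), dtype=torch.long)
assert token_type(token_type_ids).abs().sum().item() == 0.0
```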
+ model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight + model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias + + for i in range(config.num_hidden_layers): + # Encoder: start of layer + layer: BertLayer = model.roberta.encoder.layer[i] + roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i] + + # self attention + self_attn: BertSelfAttention = layer.attention.self + assert ( + roberta_layer.self_attn.k_proj.weight.data.shape + == roberta_layer.self_attn.q_proj.weight.data.shape + == roberta_layer.self_attn.v_proj.weight.data.shape + == torch.Size((config.hidden_size, config.hidden_size)) + ) + + self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight + self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias + self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight + self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias + self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight + self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias + + # self-attention output + self_output: BertSelfOutput = layer.attention.output + assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape + self_output.dense.weight = roberta_layer.self_attn.out_proj.weight + self_output.dense.bias = roberta_layer.self_attn.out_proj.bias + self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight + self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias + + # intermediate + intermediate: BertIntermediate = layer.intermediate + assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape + intermediate.dense.weight = roberta_layer.fc1.weight + intermediate.dense.bias = roberta_layer.fc1.bias + + # output + bert_output: BertOutput = layer.output + assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape + bert_output.dense.weight = roberta_layer.fc2.weight + bert_output.dense.bias = roberta_layer.fc2.bias + bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight + bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias + # end of layer + + if classification_head: + model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight + model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias + model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight + model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias + else: + # LM Head + model.lm_head.dense.weight = roberta.model.encoder.lm_head.dense.weight + model.lm_head.dense.bias = roberta.model.encoder.lm_head.dense.bias + model.lm_head.layer_norm.weight = roberta.model.encoder.lm_head.layer_norm.weight + model.lm_head.layer_norm.bias = roberta.model.encoder.lm_head.layer_norm.bias + model.lm_head.decoder.weight = roberta.model.encoder.lm_head.weight + model.lm_head.decoder.bias = roberta.model.encoder.lm_head.bias + + # Let's check that we get the same results. 
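One detail of the copy loop above, before the numerical check that follows: assignments such as `self_output.dense.weight = roberta_layer.self_attn.out_proj.weight` rebind the `nn.Parameter` itself, so both modules share storage rather than holding independent copies. A minimal illustration:

```python
import torch

# Rebinding a Parameter shares storage: mutating one side is visible on
# the other, which is exactly what the direct assignments above rely on.
src = torch.nn.Linear(4, 4)
dst = torch.nn.Linear(4, 4)
dst.weight = src.weight  # same Parameter object, not a clone
src.weight.data.fill_(1.0)
assert torch.all(dst.weight == 1.0)
```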
+ input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 + + our_output = model(input_ids)[0] + if classification_head: + their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids)) + else: + their_output = roberta.model(input_ids)[0] + print(our_output.shape, their_output.shape) + max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() + print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7 + success = torch.allclose(our_output, their_output, atol=1e-3) + print("Do both models output the same tensors?", "🔥" if success else "💩") + if not success: + raise Exception("Something went wRoNg") + + pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--classification_head", action="store_true", help="Whether to convert a final classification head." + ) + args = parser.parse_args() + convert_roberta_checkpoint_to_pytorch( + args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head + ) diff --git a/docs/transformers/src/transformers/models/roberta/modeling_flax_roberta.py b/docs/transformers/src/transformers/models/roberta/modeling_flax_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..2beb0a06b8d7c1c693f150e9d7201417c469cd62 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/modeling_flax_roberta.py @@ -0,0 +1,1500 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
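Once the conversion script above has written `pytorch_dump_folder_path`, the dump loads like any other local checkpoint; a short sketch, with a hypothetical output path:

```python
# Hypothetical path: assumes the conversion script above was run first.
from transformers import RobertaForMaskedLM

model = RobertaForMaskedLM.from_pretrained("./converted-roberta")
model.eval()  # the converted weights behave like any other checkpoint
```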
+from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" +_CONFIG_FOR_DOC = "RobertaConfig" + +remat = nn_partitioning.remat + + +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: jnp.ndarray + padding_idx: int + + Returns: jnp.ndarray + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = (input_ids != padding_idx).astype("i4") + + if mask.ndim > 2: + mask = mask.reshape((-1, mask.shape[-1])) + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + incremental_indices = incremental_indices.reshape(input_ids.shape) + else: + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + + return incremental_indices.astype("i4") + padding_idx + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. 
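The padding-aware numbering computed by `create_position_ids_from_input_ids` above is easiest to see on a worked example; the token ids are arbitrary and `padding_idx` is RoBERTa's 1:

```python
import jax.numpy as jnp

# Worked example of create_position_ids_from_input_ids: non-padding tokens
# are numbered from padding_idx + 1 = 2, and padding slots keep index 1.
input_ids = jnp.array([[0, 31414, 232, 2, 1, 1]])  # trailing 1s are padding
mask = (input_ids != 1).astype("i4")
position_ids = jnp.cumsum(mask, axis=1).astype("i4") * mask + 1
print(position_ids)  # [[2 3 4 5 1 1]]
```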
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta +class FlaxRobertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta +class FlaxRobertaSelfAttention(nn.Module): + config: RobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + f" : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
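The cache write in `_concatenate_to_cache` above hinges on `lax.dynamic_update_slice`: the current step's key/value block is written into a preallocated buffer at `cur_index`. A standalone sketch with toy shapes:

```python
import jax.numpy as jnp
from jax import lax

# Toy version of the cache update above: write one decoding step's keys
# into a preallocated (batch, max_length, heads, head_dim) buffer.
cache = jnp.zeros((1, 4, 2, 3))
new_key = jnp.ones((1, 1, 2, 3))  # a single new position
cur_index = 2
cache = lax.dynamic_update_slice(cache, new_key, (0, cur_index, 0, 0))
print(cache[0, :, 0, 0])  # [0. 0. 1. 0.] -> only position 2 was written
```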
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
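`combine_masks`, used just above to merge the cache pad mask with the incoming attention mask, logically ANDs any number of broadcastable masks and casts the result; a quick sketch:

```python
import jax.numpy as jnp
from flax.linen import combine_masks

# combine_masks ANDs broadcastable masks, as in the merge above.
pad_mask = jnp.array([[True, True, True, False]])
attn_mask = jnp.array([[True, False, True, True]])
print(combine_masks(pad_mask, attn_mask, dtype=bool))
# [[ True False  True False]]
```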
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta +class FlaxRobertaSelfOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta +class FlaxRobertaAttention(nn.Module): + config: RobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta +class FlaxRobertaIntermediate(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = 
nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta +class FlaxRobertaOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta +class FlaxRobertaLayer(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) + self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) + self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta +class FlaxRobertaLayerCollection(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxRobertaCheckpointLayer(self.config, 
name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta +class FlaxRobertaEncoder(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxRobertaLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta +class FlaxRobertaPooler(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def 
__call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +class FlaxRobertaLMHead(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.decoder = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN["gelu"](hidden_states) + hidden_states = self.layer_norm(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +class FlaxRobertaClassificationHead(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.out_proj = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS]) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = nn.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
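The LM head above implements weight tying by calling `nn.Dense.apply` with an externally supplied kernel, the transposed word-embedding table. A standalone sketch of that mechanism with toy shapes:

```python
import flax.linen as nn
import jax.numpy as jnp

# Sketch of the tying trick above: run a bias-free Dense layer with an
# externally supplied kernel instead of its own initialized parameters.
vocab, hidden = 10, 4
decoder = nn.Dense(vocab, use_bias=False)
shared_embedding = jnp.ones((vocab, hidden))  # toy (vocab, hidden) table
hidden_states = jnp.ones((1, hidden))
logits = decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
print(logits.shape)  # (1, 10)
```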
+ """ + + config_class = RobertaConfig + base_model_prefix = "roberta" + + module_class: nn.Module = None + + def __init__( + self, + config: RobertaConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.ones_like(input_ids) + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxRobertaAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta +class FlaxRobertaModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRobertaEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if 
self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaModel(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaModule + + +append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) + + +class FlaxRobertaForMaskedLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) +class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForMaskedLMModule + + +append_call_sample_docstring( + FlaxRobertaForMaskedLM, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutputWithPooling, + _CONFIG_FOR_DOC, + mask="<mask>", +) + + +class FlaxRobertaForSequenceClassificationModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + 
logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxRobertaForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta +class FlaxRobertaForMultipleChoiceModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
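The multiple-choice module above relies on a flatten/score/reshape pattern: choices are folded into the batch dimension for the encoder, and the per-choice scores are folded back afterwards. A shape-only sketch:

```python
import jax.numpy as jnp

# Shape walk-through of the multiple-choice reshaping above.
batch, num_choices, seq = 2, 4, 7
input_ids = jnp.zeros((batch, num_choices, seq), dtype="i4")
flat_ids = input_ids.reshape(-1, input_ids.shape[-1])  # (8, 7) for encoder
flat_logits = jnp.zeros((flat_ids.shape[0], 1))        # one score per choice
reshaped_logits = flat_logits.reshape(-1, num_choices)
print(reshaped_logits.shape)  # (2, 4)
```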
+ """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxRobertaForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta +class FlaxRobertaForTokenClassificationModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForTokenClassificationModule + + +append_call_sample_docstring( + FlaxRobertaForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta +class FlaxRobertaForQuestionAnsweringModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxRobertaForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRobertaForCausalLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. 
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxRobertaForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxRobertaForCausalLM", + "FlaxRobertaForMaskedLM", + "FlaxRobertaForMultipleChoice", + "FlaxRobertaForQuestionAnswering", + "FlaxRobertaForSequenceClassification", + "FlaxRobertaForTokenClassification", + "FlaxRobertaModel", + "FlaxRobertaPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roberta/modeling_roberta.py b/docs/transformers/src/transformers/models/roberta/modeling_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..f2dfa19a6a50957e43825c9c723fe31c53375a37 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/modeling_roberta.py @@ -0,0 +1,1698 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch RoBERTa model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + get_torch_version, + logging, + replace_return_docstrings, +) +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" +_CONFIG_FOR_DOC = "RobertaConfig" + + +class RobertaEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta +class RobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return 
x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
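+        # (batch, heads, q_len, head_dim) @ (batch, heads, head_dim, k_len) -> (batch, heads, q_len, k_len)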
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
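+        # A zeroed attention weight makes the query ignore that key/value position for this forward pass.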
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta +class RobertaSdpaSelfAttention(RobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from RobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. 
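+        # `current_states` below is the source of keys and values: the encoder output for cross-attention,
+        # otherwise the decoder's own hidden states.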
+ is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
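+        # During cached decoding tgt_len == 1, so `is_causal` stays False and the single new query simply
+        # attends to the whole key/value cache (nothing lies in its future).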
+ is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class RobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": RobertaSelfAttention, + "sdpa": RobertaSdpaSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA +class RobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = RobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class RobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + 
self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class RobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta +class RobertaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RobertaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.intermediate = RobertaIntermediate(config) + self.output = RobertaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross 
attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta +class RobertaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class RobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class RobertaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = RobertaConfig + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"] + _supports_sdpa = True + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaLMHead + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, RobertaLMHead): + module.bias.data.zero_() + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. 
Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->Roberta, BERT->ROBERTA
+class RobertaModel(RobertaPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = RobertaEmbeddings(config)
+        self.encoder = RobertaEncoder(config)
+
+        self.pooler = RobertaPooler(config) if add_pooling_layer else None
+
+        self.attn_implementation = config._attn_implementation
+        self.position_embedding_type = config.position_embedding_type
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
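+
+        Example (a minimal usage sketch):
+
+        ```python
+        >>> from transformers import AutoTokenizer, RobertaModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
+        >>> model = RobertaModel.from_pretrained("FacebookAI/roberta-base")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
+        ```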
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING +) +class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape 
`(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base") + >>> config = AutoConfig.from_pretrained("FacebookAI/roberta-base") + >>> config.is_decoder = True + >>> model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + 
hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
+
+@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
+class RobertaForMaskedLM(RobertaPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+            )
+
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
+        self.lm_head = RobertaLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="<mask>",
+        expected_output="' Paris'",
+        expected_loss=0.1,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
+            Used to hide legacy arguments that have been deprecated.
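+
+        Example (a minimal mask-filling sketch):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, RobertaForMaskedLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
+        >>> model = RobertaForMaskedLM.from_pretrained("FacebookAI/roberta-base")
+
+        >>> inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+        >>> tokenizer.decode(logits[0, mask_index].argmax(axis=-1))
+        ' Paris'
+        ```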
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaLMHead(nn.Module): + """Roberta Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@add_start_docstrings( + """ + RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
+ """, + ROBERTA_START_DOCSTRING, +) +class RobertaForSequenceClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="cardiffnlp/twitter-roberta-base-emotion", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'optimism'", + expected_loss=0.08, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + ROBERTA_START_DOCSTRING, +) +class RobertaForMultipleChoice(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.roberta = RobertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.roberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(reshaped_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ROBERTA_START_DOCSTRING, +) +class RobertaForTokenClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = RobertaModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="Jean-Baptiste/roberta-large-ner-english", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", + expected_loss=0.01, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. 
to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ROBERTA_START_DOCSTRING, +) +class RobertaForQuestionAnswering(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="deepset/roberta-base-squad2", + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="' puppet'", + expected_loss=0.86, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
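Example (an illustrative extractive-QA sketch using the checkpoint from this class's doc sample; the decoded span depends on it):

```python
>>> import torch
>>> from transformers import AutoTokenizer, RobertaForQuestionAnswering

>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
>>> model = RobertaForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # the answer span runs from the argmax of the start logits to the argmax of the end logits
>>> start = outputs.start_logits.argmax()
>>> end = outputs.end_logits.argmax()
>>> tokenizer.decode(inputs.input_ids[0, start : end + 1], skip_special_tokens=True)
```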
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "RobertaForCausalLM", + "RobertaForMaskedLM", + "RobertaForMultipleChoice", + "RobertaForQuestionAnswering", + "RobertaForSequenceClassification", + "RobertaForTokenClassification", + "RobertaModel", + "RobertaPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roberta/modeling_tf_roberta.py b/docs/transformers/src/transformers/models/roberta/modeling_tf_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..55361cdcc7132b9cf30735fb1ece039ce84d8b5e --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/modeling_tf_roberta.py @@ -0,0 +1,1783 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 RoBERTa model.""" + +from __future__ import annotations + +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base" +_CONFIG_FOR_DOC = "RobertaConfig" + + +class TFRobertaEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. 
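For example (a worked sketch of the rule with `padding_idx = 1` and `past_key_values_length = 0`):

```python
>>> import tensorflow as tf

>>> input_ids = tf.constant([[0, 31414, 232, 1, 1]])  # 1 is the padding id
>>> mask = tf.cast(tf.math.not_equal(input_ids, 1), dtype=input_ids.dtype)
>>> position_ids = tf.math.cumsum(mask, axis=1) * mask + 1
>>> # position_ids == [[2, 3, 4, 1, 1]]: real tokens count up from padding_idx + 1,
>>> # while padding positions stay at padding_idx
```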
+ + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + + return incremental_indices + self.padding_idx + + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) + else: + position_ids = tf.expand_dims( + tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Roberta +class TFRobertaPooler(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
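# Shape sketch: hidden_states is (batch_size, seq_length, hidden_size); taking
# hidden_states[:, 0] keeps only the first (<s>/BOS) position, so first_token_tensor
# and the tanh-activated pooled_output both have shape (batch_size, hidden_size).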
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Roberta +class TFRobertaSelfAttention(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
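# Schematically, the branching below covers four key/value cases:
#   1. cross-attention with cache:    reuse the cached encoder key/value tensors unchanged
#   2. cross-attention without cache: project encoder_hidden_states into key/value
#   3. self-attention with cache:     project the current tokens, then concatenate onto the
#                                     cached key/value along the sequence axis (axis=2)
#   4. self-attention without cache:  project the current tokens only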
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFRobertaModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Roberta +class TFRobertaSelfOutput(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Roberta +class TFRobertaAttention(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFRobertaSelfAttention(config, name="self") + self.dense_output = TFRobertaSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], 
input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Roberta +class TFRobertaIntermediate(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Roberta +class TFRobertaOutput(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Roberta +class TFRobertaLayer(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFRobertaAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFRobertaAttention(config, name="crossattention") + self.intermediate = TFRobertaIntermediate(config, name="intermediate") + self.bert_output = TFRobertaOutput(config, name="output") + + def 
call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: Tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Roberta +class TFRobertaEncoder(keras.layers.Layer): + def __init__(self, config: RobertaConfig, **kwargs): + super().__init__(**kwargs) + self.config 
= config + self.layer = [TFRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: Tuple[Tuple[tf.Tensor]] | None, + use_cache: Optional[bool], + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFRobertaMainLayer(keras.layers.Layer): + config_class = RobertaConfig + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.encoder = TFRobertaEncoder(config, name="encoder") + self.pooler = TFRobertaPooler(config, name="pooler") if add_pooling_layer else None + # The embeddings must be the last declaration in order to follow the weights order + self.embeddings = TFRobertaEmbeddings(config, name="embeddings") + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = 
value + self.embeddings.vocab_size = shape_list(value)[0] + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
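# Worked toy example: a padding mask [[1, 1, 0]] is reshaped to (batch_size, 1, 1, seq_length)
# so it broadcasts over heads and query positions, then mapped through (1 - mask) * -10000.0
# to [[[[0.0, 0.0, -10000.0]]]]; added to the raw attention scores before the softmax, this
# effectively removes the padded position from the attention distribution.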
+ attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]` + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make it broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf.
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + + +class TFRobertaPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RobertaConfig + base_model_prefix = "roberta" + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. 
Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class TFRobertaModel(TFRobertaPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.roberta = TFRobertaMainLayer(config, name="roberta") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + + +class TFRobertaLMHead(keras.layers.Layer): + """Roberta Head for masked language modeling.""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) +class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' 
represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
+        self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+        return self.name + "/" + self.lm_head.name
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFMaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="<mask>",
+        expected_output="' Paris'",
+        expected_loss=0.1,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFMaskedLMOutput(
+            loss=loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "roberta", None) is not None:
+            with tf.name_scope(self.roberta.name):
+                self.roberta.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
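+
+# A minimal fill-mask usage sketch for the class above (illustrative only, not part of the library;
+# it assumes network access to the public FacebookAI/roberta-base checkpoint):
+#
+#     import tensorflow as tf
+#     from transformers import AutoTokenizer, TFRobertaForMaskedLM
+#
+#     tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
+#     model = TFRobertaForMaskedLM.from_pretrained("FacebookAI/roberta-base")
+#     inputs = tokenizer("The capital of France is <mask>.", return_tensors="tf")
+#     logits = model(**inputs).logits
+#     mask_index = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0])
+#     predicted_id = int(tf.math.argmax(logits[0, mask_index]))
+#     tokenizer.decode([predicted_id])  # ' Paris' for this checkpoint, per the code-sample decorator above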
+
+
+class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
+
+    def __init__(self, config: RobertaConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `TFRobertaLMHeadModel` as a standalone, add `is_decoder=True`.")
+
+        self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
+        self.lm_head = TFRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head")
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    def get_prefix_bias_name(self):
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+        return self.name + "/" + self.lm_head.name
+
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = tf.ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
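+
+    # A generation sketch for this head (illustrative only; RoBERTa was not pretrained as a decoder,
+    # so a checkpoint loaded with is_decoder=True is assumed):
+    #
+    #     from transformers import AutoTokenizer, TFRobertaForCausalLM
+    #
+    #     tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
+    #     model = TFRobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True)
+    #     inputs = tokenizer("My favorite fruit is", return_tensors="tf")
+    #     generated = model.generate(**inputs, max_new_tokens=5)
+    #     tokenizer.decode(generated[0])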
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFCausalLMOutputWithCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
+        r"""
+        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up
+            decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead
+            of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+            (see `past_key_values`). Set to `False` during training and to `True` during generation.
+        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
+            config.vocab_size - 1]`.
+        """
+        outputs = self.roberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+        logits = self.lm_head(hidden_states=sequence_output, training=training)
+        loss = None
+
+        if labels is not None:
+            # shift labels to the left and cut last logit token
+            shifted_logits = logits[:, :-1]
+            labels = labels[:, 1:]
+            loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFCausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "roberta", None) is not None:
+            with tf.name_scope(self.roberta.name):
+                self.roberta.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
+
+
+class TFRobertaClassificationHead(keras.layers.Layer):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(
+            config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = keras.layers.Dropout(classifier_dropout)
+        self.out_proj = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+        )
+        self.config = config
+
+    def call(self, features, training=False):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x, training=training)
+        x = self.dense(x)
+        x = self.dropout(x, training=training)
+        x = self.out_proj(x)
+        return x
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+    pooled output) e.g. for GLUE tasks.
+    """,
+    ROBERTA_START_DOCSTRING,
+)
+class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
+        self.classifier = TFRobertaClassificationHead(config, name="classifier")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="cardiffnlp/twitter-roberta-base-emotion",
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="'optimism'",
+        expected_loss=0.08,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
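+
+        Example (an illustrative sketch added for this doc; it uses the same public
+        `cardiffnlp/twitter-roberta-base-emotion` checkpoint referenced in the code-sample decorator above):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoTokenizer, TFRobertaForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-emotion")
+        >>> model = TFRobertaForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-emotion")
+        >>> inputs = tokenizer("I can't wait for the weekend!", return_tensors="tf")
+        >>> logits = model(**inputs).logits
+        >>> model.config.id2label[int(tf.math.argmax(logits, axis=-1)[0])]  # one of the emotion labels, e.g. 'optimism'
+        ```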
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta = TFRobertaMainLayer(config, name="roberta") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + outputs = self.roberta( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="ydshieh/roberta-large-ner-english", + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", + expected_loss=0.01, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ROBERTA_START_DOCSTRING, +) +class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss): + # names with a '.' 
represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
+        self.qa_outputs = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="ydshieh/roberta-base-squad2",
+        output_type=TFQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="' puppet'",
+        expected_loss=0.86,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        start_positions: np.ndarray | tf.Tensor | None = None,
+        end_positions: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
+        r"""
+        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
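+
+        Example (an illustrative sketch added for this doc; the decoded answer mirrors the `' puppet'`
+        expected output in the code-sample decorator above):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoTokenizer, TFRobertaForQuestionAnswering
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("ydshieh/roberta-base-squad2")
+        >>> model = TFRobertaForQuestionAnswering.from_pretrained("ydshieh/roberta-base-squad2")
+        >>> question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        >>> inputs = tokenizer(question, context, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> start = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
+        >>> end = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
+        >>> tokenizer.decode(inputs["input_ids"][0][start : end + 1])
+        ' puppet'
+        ```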
+ """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta", None) is not None: + with tf.name_scope(self.roberta.name): + self.roberta.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFRobertaForCausalLM", + "TFRobertaForMaskedLM", + "TFRobertaForMultipleChoice", + "TFRobertaForQuestionAnswering", + "TFRobertaForSequenceClassification", + "TFRobertaForTokenClassification", + "TFRobertaMainLayer", + "TFRobertaModel", + "TFRobertaPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roberta/tokenization_roberta.py b/docs/transformers/src/transformers/models/roberta/tokenization_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec6cbbf9866ea3a5a7bbf89a5606e91317f7374 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/tokenization_roberta.py @@ -0,0 +1,402 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for RoBERTa.""" + +import json +import os +from functools import lru_cache +from typing import List, Optional, Tuple + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. 
When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want
+    lookup tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
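+
+# An illustrative sketch of the mapping above (added for this doc, not executed at import time):
+# printable bytes map to themselves, while bytes the BPE code would choke on are shifted to unused
+# unicode code points, e.g.:
+#
+#     >>> mapping = bytes_to_unicode()
+#     >>> mapping[ord("A")]
+#     'A'
+#     >>> mapping[ord(" ")]  # the space byte becomes the familiar "Ġ" marker
+#     'Ġ'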
+
+
+class RobertaTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a RoBERTa tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```python
+    >>> from transformers import RobertaTokenizer
+
+    >>> tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
+    >>> tokenizer("Hello world")["input_ids"]
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+            other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.)
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+
+        # Mask token behaves like a normal word, i.e.
include the space before it + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + + # these special tokens are not part of the vocab.json, let's add them in the correct order + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + vocab = dict(self.encoder).copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + 
return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A RoBERTa sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does
+        not make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
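+
+        Example (an illustrative sketch added for this doc; `tokenizer` is assumed to be an already
+        initialized `RobertaTokenizer`, and the IDs are placeholders):
+
+        ```python
+        >>> tokenizer.create_token_type_ids_from_sequences([31414], [232])
+        [0, 0, 0, 0, 0, 0]
+        ```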
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + +__all__ = ["RobertaTokenizer"] diff --git a/docs/transformers/src/transformers/models/roberta/tokenization_roberta_fast.py b/docs/transformers/src/transformers/models/roberta/tokenization_roberta_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..cf288f4d8c7c11d7482ea1e38da101f588215894 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta/tokenization_roberta_fast.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for RoBERTa.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import processors + +from ...tokenization_utils_base import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_roberta import RobertaTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + + +class RobertaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" RoBERTa tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 + tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import RobertaTokenizerFast + + >>> tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base") + >>> tokenizer("Hello world")["input_ids"] + [0, 31414, 232, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [0, 20920, 232, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. + + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. 
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole
+            sequence instead of per-token classification). It is the first token of the sequence when built with
+            special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+            other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.)
+        trim_offsets (`bool`, *optional*, defaults to `True`):
+            Whether the post processing step should trim offsets to avoid including whitespaces.
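+
+        Example (an illustrative sketch of `add_prefix_space`, added for this doc; it mirrors the
+        encoding shown at the top of this docstring):
+
+        ```python
+        >>> from transformers import RobertaTokenizerFast
+
+        >>> tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
+        >>> tokenizer("Hello world")["input_ids"]  # now encoded as if it were " Hello world"
+        [0, 20920, 232, 2]
+        ```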
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = RobertaTokenizer + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + trim_offsets=True, + **kwargs, + ): + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + **kwargs, + ) + + tokenizer_component = "post_processor" + tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) + if tokenizer_component_instance: + state = json.loads(tokenizer_component_instance.__getstate__()) + + # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class` + if "sep" in state: + state["sep"] = tuple(state["sep"]) + if "cls" in state: + state["cls"] = tuple(state["cls"]) + + changes_to_apply = False + + if state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + state["add_prefix_space"] = add_prefix_space + changes_to_apply = True + + if state.get("trim_offsets", trim_offsets) != trim_offsets: + state["trim_offsets"] = trim_offsets + changes_to_apply = True + + if changes_to_apply: + component_class = getattr(processors, state.pop("type")) + new_value = component_class(**state) + setattr(self.backend_tokenizer, tokenizer_component, new_value) + + @property + def mask_token(self) -> str: + """ + `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not + having been set. + + Roberta tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily + comprise the space before the **. + """ + if self._mask_token is None: + if self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @mask_token.setter + def mask_token(self, value): + """ + Overriding the default behavior of the mask token to have it eat the space before it. + + This is needed to preserve backward compatibility with all the previously used models based on Roberta. + """ + # Mask token behave like a normal word, i.e. include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._batch_encode_plus(*args, **kwargs) + + def _encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." 
+ ) + + return super()._encode_plus(*args, **kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + +__all__ = ["RobertaTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/roberta_prelayernorm/__init__.py b/docs/transformers/src/transformers/models/roberta_prelayernorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..208878343d24b3cec254a918f59d69841d1110c7 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta_prelayernorm/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_roberta_prelayernorm import * + from .modeling_flax_roberta_prelayernorm import * + from .modeling_roberta_prelayernorm import * + from .modeling_tf_roberta_prelayernorm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py b/docs/transformers/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..71ecbd4474d4f00495b73fe8dfd051f376ddb56c --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py @@ -0,0 +1,157 @@ +# coding=utf-8 +# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RoBERTa-PreLayerNorm configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.roberta.configuration_roberta.RobertaConfig with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,RoBERTa->RoBERTa-PreLayerNorm,Roberta->RobertaPreLayerNorm,roberta->roberta-prelayernorm +class RobertaPreLayerNormConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RobertaPreLayerNormModel`] or a [`TFRobertaPreLayerNormModel`]. It is + used to instantiate a RoBERTa-PreLayerNorm model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa-PreLayerNorm + [andreasmadsen/efficient_mlm_m0.40](https://huggingface.co/andreasmadsen/efficient_mlm_m0.40) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa-PreLayerNorm model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaPreLayerNormModel`] or [`TFRobertaPreLayerNormModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaPreLayerNormModel`] or [`TFRobertaPreLayerNormModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import RobertaPreLayerNormConfig, RobertaPreLayerNormModel + + >>> # Initializing a RoBERTa-PreLayerNorm configuration + >>> configuration = RobertaPreLayerNormConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaPreLayerNormModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "roberta-prelayernorm" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->RobertaPreLayerNorm +class RobertaPreLayerNormOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["RobertaPreLayerNormConfig", "RobertaPreLayerNormOnnxConfig"] diff --git 
a/docs/transformers/src/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a6b03162f60bafc2231df547f1c5ee77714fe5
--- /dev/null
+++ b/docs/transformers/src/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert RoBERTa-PreLayerNorm checkpoint."""
+
+import argparse
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from transformers import AutoTokenizer, RobertaPreLayerNormConfig, RobertaPreLayerNormForMaskedLM
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def convert_roberta_prelayernorm_checkpoint_to_pytorch(checkpoint_repo: str, pytorch_dump_folder_path: str):
+    """
+    Copy/paste/tweak roberta_prelayernorm's weights to our BERT structure.
+    """
+    # convert configuration
+    config = RobertaPreLayerNormConfig.from_pretrained(
+        checkpoint_repo, architectures=["RobertaPreLayerNormForMaskedLM"]
+    )
+
+    # convert state_dict
+    original_state_dict = torch.load(
+        hf_hub_download(repo_id=checkpoint_repo, filename="pytorch_model.bin"), weights_only=True
+    )
+    state_dict = {}
+    for tensor_key, tensor_value in original_state_dict.items():
+        # The transformer implementation gives the model a unique name, rather than overwriting 'roberta'
+        if tensor_key.startswith("roberta."):
+            tensor_key = "roberta_prelayernorm." + tensor_key[len("roberta.") :]
+
+        # The original implementation contains weights which are not used, remove them from the state_dict
+        if tensor_key.endswith(".self.LayerNorm.weight") or tensor_key.endswith(".self.LayerNorm.bias"):
+            continue
+
+        state_dict[tensor_key] = tensor_value
+
+    model = RobertaPreLayerNormForMaskedLM.from_pretrained(
+        pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+    )
+    model.save_pretrained(pytorch_dump_folder_path)
+
+    # convert tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint_repo)
+    tokenizer.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--checkpoint-repo",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the official PyTorch dump, e.g. 'andreasmadsen/efficient_mlm_m0.40'.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+ ) + args = parser.parse_args() + convert_roberta_prelayernorm_checkpoint_to_pytorch(args.checkpoint_repo, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..1e691c047bd4fc23b7ef83a58f37407a21aa088f --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py @@ -0,0 +1,1527 @@ +# coding=utf-8 +# Copyright 2022 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax RoBERTa-PreLayerNorm model.""" + +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40" +_CONFIG_FOR_DOC = "RobertaPreLayerNormConfig" + +remat = nn_partitioning.remat + + +# Copied from transformers.models.roberta.modeling_flax_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: jnp.ndarray + padding_idx: int + + Returns: jnp.ndarray + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
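+    # Illustrative worked example (not part of the computation): with padding_idx=1
+    # and input_ids [[1, 5, 7, 1]], the mask is [[0, 1, 1, 0]], the masked cumulative
+    # sum is [[0, 1, 2, 0]], and the returned position ids are [[1, 2, 3, 1]] --
+    # padded positions simply keep the padding index.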
+    mask = (input_ids != padding_idx).astype("i4")
+
+    if mask.ndim > 2:
+        mask = mask.reshape((-1, mask.shape[-1]))
+        incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
+        incremental_indices = incremental_indices.reshape(input_ids.shape)
+    else:
+        incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
+
+    return incremental_indices.astype("i4") + padding_idx
+
+
+ROBERTA_PRELAYERNORM_START_DOCSTRING = r"""
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
+
+    This model is also a
+    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
+    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
+    behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the
+            model. Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`numpy.ndarray` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+""" + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormSelfAttention(nn.Module): + config: RobertaPreLayerNormConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states 
from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position 
at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxRobertaPreLayerNormSelfOutput(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class FlaxRobertaPreLayerNormAttention(nn.Module): + config: RobertaPreLayerNormConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxRobertaPreLayerNormSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + self.output = FlaxRobertaPreLayerNormSelfOutput(self.config, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + hidden_states_pre_layer_norm = self.LayerNorm(hidden_states) + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states_pre_layer_norm, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + 
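+
+# A schematic view of the pre-LayerNorm residual block implemented above (illustrative
+# pseudocode, assuming `ln`, `self_attn` and `out_proj` stand for the LayerNorm,
+# self-attention and output submodules defined in FlaxRobertaPreLayerNormAttention):
+#
+#     post-LN (standard RoBERTa):  hidden = ln(out_proj(self_attn(hidden)) + hidden)
+#     pre-LN (this model):         hidden = out_proj(self_attn(ln(hidden))) + hidden
+#
+# Normalization moves inside the residual branch, so the skip connection itself is
+# never normalized, which tends to stabilize training of deep transformers.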
+ +class FlaxRobertaPreLayerNormIntermediate(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxRobertaPreLayerNormOutput(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = hidden_states + attention_output + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormLayer(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxRobertaPreLayerNormAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) + self.intermediate = FlaxRobertaPreLayerNormIntermediate(self.config, dtype=self.dtype) + self.output = FlaxRobertaPreLayerNormOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxRobertaPreLayerNormAttention(self.config, causal=False, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormLayerCollection(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = 
jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxRobertaPreLayerNormCheckpointLayer = remat(FlaxRobertaPreLayerNormLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxRobertaPreLayerNormCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxRobertaPreLayerNormLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormEncoder(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxRobertaPreLayerNormLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from 
transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->RobertaPreLayerNorm
+class FlaxRobertaPreLayerNormPooler(nn.Module):
+    config: RobertaPreLayerNormConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states):
+        cls_hidden_state = hidden_states[:, 0]
+        cls_hidden_state = self.dense(cls_hidden_state)
+        return nn.tanh(cls_hidden_state)
+
+
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaLMHead with Roberta->RobertaPreLayerNorm
+class FlaxRobertaPreLayerNormLMHead(nn.Module):
+    config: RobertaPreLayerNormConfig
+    dtype: jnp.dtype = jnp.float32
+    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        self.decoder = nn.Dense(
+            self.config.vocab_size,
+            dtype=self.dtype,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
+
+    def __call__(self, hidden_states, shared_embedding=None):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = ACT2FN["gelu"](hidden_states)
+        hidden_states = self.layer_norm(hidden_states)
+
+        if shared_embedding is not None:
+            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+        else:
+            hidden_states = self.decoder(hidden_states)
+
+        bias = jnp.asarray(self.bias, self.dtype)
+        hidden_states += bias
+        return hidden_states
+
+
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaClassificationHead with Roberta->RobertaPreLayerNorm
+class FlaxRobertaPreLayerNormClassificationHead(nn.Module):
+    config: RobertaPreLayerNormConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+        classifier_dropout = (
+            self.config.classifier_dropout
+            if self.config.classifier_dropout is not None
+            else self.config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(rate=classifier_dropout)
+        self.out_proj = nn.Dense(
+            self.config.num_labels,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+
+    def __call__(self, hidden_states, deterministic=True):
+        hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = nn.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaPreTrainedModel with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
+class FlaxRobertaPreLayerNormPreTrainedModel(FlaxPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+ """ + + config_class = RobertaPreLayerNormConfig + base_model_prefix = "roberta_prelayernorm" + + module_class: nn.Module = None + + def __init__( + self, + config: RobertaPreLayerNormConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.ones_like(input_ids) + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxRobertaPreLayerNormAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +class FlaxRobertaPreLayerNormModule(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxRobertaPreLayerNormEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRobertaPreLayerNormEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxRobertaPreLayerNormPooler(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + 
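+        # Because pre-LN blocks add residuals without normalizing them, the encoder
+        # output below is passed through one final LayerNorm before pooling.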
+        hidden_states = outputs[0]
+        hidden_states = self.LayerNorm(hidden_states)
+        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
+
+        if not return_dict:
+            # if pooled is None, don't return it
+            if pooled is None:
+                return (hidden_states,) + outputs[1:]
+            return (hidden_states, pooled) + outputs[1:]
+
+        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            pooler_output=pooled,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_PRELAYERNORM_START_DOCSTRING,
+)
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaModel with Roberta->RobertaPreLayerNorm
+class FlaxRobertaPreLayerNormModel(FlaxRobertaPreLayerNormPreTrainedModel):
+    module_class = FlaxRobertaPreLayerNormModule
+
+
+append_call_sample_docstring(
+    FlaxRobertaPreLayerNormModel,
+    _CHECKPOINT_FOR_DOC,
+    FlaxBaseModelOutputWithPooling,
+    _CONFIG_FOR_DOC,
+)
+
+
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLMModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
+class FlaxRobertaPreLayerNormForMaskedLMModule(nn.Module):
+    config: RobertaPreLayerNormConfig
+    dtype: jnp.dtype = jnp.float32
+    gradient_checkpointing: bool = False
+
+    def setup(self):
+        self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
+            config=self.config,
+            add_pooling_layer=False,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+        self.lm_head = FlaxRobertaPreLayerNormLMHead(config=self.config, dtype=self.dtype)
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        position_ids,
+        head_mask,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        # Model
+        outputs = self.roberta_prelayernorm(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
+            head_mask,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.tie_word_embeddings:
+            shared_embedding = self.roberta_prelayernorm.variables["params"]["embeddings"]["word_embeddings"][
+                "embedding"
+            ]
+        else:
+            shared_embedding = None
+
+        # Compute the prediction scores
+        logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
+
+        if not return_dict:
+            return (logits,) + outputs[1:]
+
+        return FlaxMaskedLMOutput(
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING
+)
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLM with Roberta->RobertaPreLayerNorm
+class FlaxRobertaPreLayerNormForMaskedLM(FlaxRobertaPreLayerNormPreTrainedModel):
+    module_class = FlaxRobertaPreLayerNormForMaskedLMModule
+
+
+append_call_sample_docstring(
+    FlaxRobertaPreLayerNormForMaskedLM,
+    _CHECKPOINT_FOR_DOC,
+    FlaxBaseModelOutputWithPooling,
+    _CONFIG_FOR_DOC,
+    mask="<mask>",
+)
+
+
+# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassificationModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
+class
FlaxRobertaPreLayerNormForSequenceClassificationModule(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.classifier = FlaxRobertaPreLayerNormClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RobertaPreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top + of the pooled output) e.g. for GLUE tasks. + """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassification with Roberta->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormForSequenceClassification(FlaxRobertaPreLayerNormPreTrainedModel): + module_class = FlaxRobertaPreLayerNormForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxRobertaPreLayerNormForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm +class FlaxRobertaPreLayerNormForMultipleChoiceModule(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled + output and a softmax) e.g. for RocStories/SWAG tasks. + """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMultipleChoice with Roberta->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormForMultipleChoice(FlaxRobertaPreLayerNormPreTrainedModel): + module_class = FlaxRobertaPreLayerNormForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxRobertaPreLayerNormForMultipleChoice, + ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"), +) +append_call_sample_docstring( + FlaxRobertaPreLayerNormForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm +class FlaxRobertaPreLayerNormForTokenClassificationModule(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RobertaPreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForTokenClassification with Roberta->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormForTokenClassification(FlaxRobertaPreLayerNormPreTrainedModel): + module_class = FlaxRobertaPreLayerNormForTokenClassificationModule + + +append_call_sample_docstring( + FlaxRobertaPreLayerNormForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm +class FlaxRobertaPreLayerNormForQuestionAnsweringModule(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RobertaPreLayerNorm Model with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForQuestionAnswering with Roberta->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormForQuestionAnswering(FlaxRobertaPreLayerNormPreTrainedModel): + module_class = FlaxRobertaPreLayerNormForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxRobertaPreLayerNormForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLMModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm +class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module): + config: RobertaPreLayerNormConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxRobertaPreLayerNormLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta_prelayernorm.variables["params"]["embeddings"]["word_embeddings"][ + "embedding" + ] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + RobertaPreLayerNorm Model with a language modeling head on top (a linear layer on top of the hidden-states output) + e.g for autoregressive tasks. + """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLM with Roberta->RobertaPreLayerNorm +class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel): + module_class = FlaxRobertaPreLayerNormForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. 
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxRobertaPreLayerNormForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxRobertaPreLayerNormForCausalLM", + "FlaxRobertaPreLayerNormForMaskedLM", + "FlaxRobertaPreLayerNormForMultipleChoice", + "FlaxRobertaPreLayerNormForQuestionAnswering", + "FlaxRobertaPreLayerNormForSequenceClassification", + "FlaxRobertaPreLayerNormForTokenClassification", + "FlaxRobertaPreLayerNormModel", + "FlaxRobertaPreLayerNormPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0c40b222c19685f269609d44cc5a39fb553696 --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -0,0 +1,1558 @@ +# coding=utf-8 +# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch RoBERTa-PreLayerNorm model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40" +_CONFIG_FOR_DOC = "RobertaPreLayerNormConfig" + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->RobertaPreLayerNorm +class RobertaPreLayerNormEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
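+                # For example (illustrative values): with padding_idx=1 and
+                # input_ids [[0, 31414, 232, 2, 1, 1]], the helper cumulatively counts non-padding
+                # tokens and offsets by padding_idx, giving position_ids [[2, 3, 4, 5, 1, 1]];
+                # padding positions keep padding_idx.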
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        # If token_type_ids is not provided, fall back to the all-zeros buffer registered in the
+        # constructor (the usual case when it is auto-generated). Using the registered buffer lets
+        # users trace the model without passing token_type_ids and solves issue #5664.
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RobertaPreLayerNorm
+class RobertaPreLayerNormSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+        self.is_decoder = config.is_decoder
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        use_cache = past_key_value is not None
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
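+        # Shapes: query_layer is (batch, num_heads, q_len, head_dim) and the transposed key_layer is
+        # (batch, num_heads, head_dim, k_len), so the product is (batch, num_heads, q_len, k_len).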
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaPreLayerNormModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
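+        # Zeroing an entry of attention_probs removes that key's contribution for the
+        # corresponding query, i.e. a whole token is dropped from the attention average.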
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class RobertaPreLayerNormSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states + input_tensor
+        return hidden_states
+
+
+class RobertaPreLayerNormAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__()
+        self.self = RobertaPreLayerNormSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.output = RobertaPreLayerNormSelfOutput(config)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pruned_heads = set()
+
+    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        hidden_states_pre_layer_norm = self.LayerNorm(hidden_states)
+        self_outputs = self.self(
+            hidden_states_pre_layer_norm,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class RobertaPreLayerNormIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.LayerNorm(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class RobertaPreLayerNormOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states + input_tensor
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RobertaPreLayerNorm
+class RobertaPreLayerNormLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = RobertaPreLayerNormAttention(config)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = RobertaPreLayerNormAttention(config, position_embedding_type="absolute")
+        self.intermediate = RobertaPreLayerNormIntermediate(config)
+        self.output = RobertaPreLayerNormOutput(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        # if decoder, the last output is tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+                    " by setting `config.add_cross_attention=True`"
+                )
+
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                cross_attn_past_key_value,
+                output_attentions,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RobertaPreLayerNorm
+class RobertaPreLayerNormEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([RobertaPreLayerNormLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class RobertaPreLayerNormPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+    """
+
+    config_class = RobertaPreLayerNormConfig
+    base_model_prefix = "roberta_prelayernorm"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["RobertaPreLayerNormEmbeddings", "RobertaPreLayerNormSelfAttention"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaPreLayerNormLMHead
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, RobertaPreLayerNormLMHead):
+            module.bias.data.zero_()
+
+
+ROBERTA_PRELAYERNORM_START_DOCSTRING = r"""
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the
+            model. Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+            This parameter can only be used when the model is initialized with the `type_vocab_size` parameter set to a
+            value >= 2. All values in this tensor should always be < type_vocab_size.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_PRELAYERNORM_START_DOCSTRING,
+)
+class RobertaPreLayerNormModel(RobertaPreLayerNormPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+    Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+
+    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = RobertaPreLayerNormEmbeddings(config)
+        self.encoder = RobertaPreLayerNormEncoder(config)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.pooler = RobertaPreLayerNormPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
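+        # get_extended_attention_mask broadcasts the (batch_size, seq_length) 0/1 mask to
+        # (batch_size, 1, 1, seq_length) and turns it into an additive bias: 0.0 for positions
+        # to keep and the dtype's minimum value for positions to mask.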
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.LayerNorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """RoBERTa-PreLayerNorm Model with a `language modeling` head on top for CLM fine-tuning.""",
+    ROBERTA_PRELAYERNORM_START_DOCSTRING,
+)
+# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer
+class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning(
+                "If you want to use `RobertaPreLayerNormLMHeadModel` as a standalone, add `is_decoder=True.`"
+            )
+
+        self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
+        self.lm_head = RobertaPreLayerNormLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RobertaPreLayerNormForCausalLM, AutoConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40") + >>> config = AutoConfig.from_pretrained("andreasmadsen/efficient_mlm_m0.40") + >>> config.is_decoder = True + >>> model = RobertaPreLayerNormForCausalLM.from_pretrained("andreasmadsen/efficient_mlm_m0.40", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING +) +class RobertaPreLayerNormForMaskedLM(RobertaPreLayerNormPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `RobertaPreLayerNormForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+            )
+
+        self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
+        self.lm_head = RobertaPreLayerNormLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="<mask>",
+        expected_output="' Paris'",
+        expected_loss=0.69,
+    )
+    # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.forward with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
+            Used to hide legacy arguments that have been deprecated.
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->RobertaPreLayerNorm +class RobertaPreLayerNormLMHead(nn.Module): + """RobertaPreLayerNorm Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@add_start_docstrings( + """ + RoBERTa-PreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top + of the pooled output) e.g. for GLUE tasks. 
+ """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +class RobertaPreLayerNormForSequenceClassification(RobertaPreLayerNormPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False) + self.classifier = RobertaPreLayerNormClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.forward with roberta->roberta_prelayernorm + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
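+
+        Example (an illustrative sketch, not part of the original file; it assumes the
+        `andreasmadsen/efficient_mlm_m0.40` checkpoint and a hypothetical `num_labels=2` head,
+        which is randomly initialized on top of the pretrained encoder):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, RobertaPreLayerNormForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+        >>> model = RobertaPreLayerNormForSequenceClassification.from_pretrained(
+        ...     "andreasmadsen/efficient_mlm_m0.40", num_labels=2
+        ... )
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits  # shape (1, 2): one score per label
+        ```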
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.roberta_prelayernorm(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled
+    output and a softmax) e.g. for RocStories/SWAG tasks.
+    """,
+    ROBERTA_PRELAYERNORM_START_DOCSTRING,
+)
+# Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
+class RobertaPreLayerNormForMultipleChoice(RobertaPreLayerNormPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.roberta_prelayernorm = RobertaPreLayerNormModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(
+        ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+    )
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MultipleChoiceModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+            `input_ids` above)
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.roberta_prelayernorm(
+            flat_input_ids,
+            position_ids=flat_position_ids,
+            token_type_ids=flat_token_type_ids,
+            attention_mask=flat_attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=flat_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(reshaped_logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    RobertaPreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +class RobertaPreLayerNormForTokenClassification(RobertaPreLayerNormPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.forward with roberta->roberta_prelayernorm + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->RobertaPreLayerNorm +class RobertaPreLayerNormClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. 
to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+@add_start_docstrings(
+    """
+    RobertaPreLayerNorm Model with a span classification head on top for extractive question-answering tasks like SQuAD
+    (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    ROBERTA_PRELAYERNORM_START_DOCSTRING,
+)
+class RobertaPreLayerNormForQuestionAnswering(RobertaPreLayerNormPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.forward with roberta->roberta_prelayernorm
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "RobertaPreLayerNormForCausalLM", + "RobertaPreLayerNormForMaskedLM", + "RobertaPreLayerNormForMultipleChoice", + "RobertaPreLayerNormForQuestionAnswering", + "RobertaPreLayerNormForSequenceClassification", + "RobertaPreLayerNormForTokenClassification", + "RobertaPreLayerNormModel", + "RobertaPreLayerNormPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py b/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..a11f6151c4dba1dc801931da94f2c9e126a4bb1d --- /dev/null +++ b/docs/transformers/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py @@ -0,0 +1,1808 @@ +# coding=utf-8 +# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 RoBERTa-PreLayerNorm model.""" + +from __future__ import annotations + +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40" +_CONFIG_FOR_DOC = "RobertaPreLayerNormConfig" + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings with Roberta->RobertaPreLayerNorm +class TFRobertaPreLayerNormEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. 
Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + + return incremental_indices + self.padding_idx + + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) + else: + position_ids = tf.expand_dims( + tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->RobertaPreLayerNorm +class TFRobertaPreLayerNormPooler(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
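+        # hidden_states has shape (batch_size, seq_length, hidden_size); slicing
+        # out position 0 (the <s>/[CLS] token) below yields (batch_size, hidden_size),
+        # which the dense + tanh projection keeps at hidden_size.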
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->RobertaPreLayerNorm +class TFRobertaPreLayerNormSelfAttention(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
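+        # The key/value projections are resolved in four cases below:
+        #   1. cross-attention with a cache: reuse the cached encoder keys/values,
+        #   2. cross-attention without a cache: project `encoder_hidden_states`,
+        #   3. cached self-attention: project the current step and concatenate it
+        #      onto the cached keys/values along the sequence axis (axis=2),
+        #   4. plain self-attention: project the current `hidden_states` only.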
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k, v from the cross-attention cache
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+
+        if self.is_decoder:
+            # if cross-attention, save Tuple(tf.Tensor, tf.Tensor) of all cross-attention key/value_states.
+            # Further calls to the cross-attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case).
+            # if uni-directional self-attention (decoder), save Tuple(tf.Tensor, tf.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to the current projected key/value_states (third "elif" case).
+            # if encoder bi-directional self-attention, `past_key_value` is always `None`.
+            past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        # (batch_size, num_heads, seq_len_q, seq_len_k)
+        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+        attention_scores = tf.divide(attention_scores, dk)
+
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the TFRobertaPreLayerNormModel call() function)
+            attention_scores = tf.add(attention_scores, attention_mask)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
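+        # attention_probs: (batch_size, num_attention_heads, seq_len_q, seq_len_k);
+        # each row sums to 1 after the softmax above, so dropping an entry removes
+        # one key position from one query's attention for this forward pass.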
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +class TFRobertaPreLayerNormSelfOutput(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = hidden_states + input_tensor + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFRobertaPreLayerNormAttention(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFRobertaPreLayerNormSelfAttention(config, name="self") + self.dense_output = TFRobertaPreLayerNormSelfOutput(config, name="output") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention.prune_heads + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + hidden_states_pre_layer_norm = self.LayerNorm(inputs=input_tensor) + self_outputs = self.self_attention( + hidden_states=hidden_states_pre_layer_norm, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we 
output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFRobertaPreLayerNormIntermediate(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.LayerNorm(inputs=hidden_states) + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFRobertaPreLayerNormOutput(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = hidden_states + input_tensor + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->RobertaPreLayerNorm +class TFRobertaPreLayerNormLayer(keras.layers.Layer): + def __init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFRobertaPreLayerNormAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFRobertaPreLayerNormAttention(config, name="crossattention") + self.intermediate = TFRobertaPreLayerNormIntermediate(config, name="intermediate") + self.bert_output = 
TFRobertaPreLayerNormOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: Tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->RobertaPreLayerNorm +class TFRobertaPreLayerNormEncoder(keras.layers.Layer): + def 
__init__(self, config: RobertaPreLayerNormConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layer = [TFRobertaPreLayerNormLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: Tuple[Tuple[tf.Tensor]] | None, + use_cache: Optional[bool], + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFRobertaPreLayerNormMainLayer(keras.layers.Layer): + config_class = RobertaPreLayerNormConfig + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.encoder = TFRobertaPreLayerNormEncoder(config, name="encoder") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.pooler = TFRobertaPreLayerNormPooler(config, name="pooler") if add_pooling_layer else None + # The embeddings must be the last declaration in order to follow the weights order + self.embeddings = TFRobertaPreLayerNormEmbeddings(config, name="embeddings") + + def get_input_embeddings(self) -> keras.layers.Layer: + return 
self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
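+        # Shape walk-through for the no-cache case: a (batch_size, seq_length)
+        # padding mask becomes (batch_size, 1, 1, seq_length) for encoders below,
+        # while decoders additionally fold in a causal mask, giving
+        # (batch_size, 1, seq_length, seq_length).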
+        attention_mask_shape = shape_list(attention_mask)
+
+        mask_seq_length = seq_length + past_key_values_length
+        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+        if self.is_decoder:
+            seq_ids = tf.range(mask_seq_length)
+            causal_mask = tf.less_equal(
+                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+                seq_ids[None, :, None],
+            )
+            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+            extended_attention_mask = causal_mask * attention_mask[:, None, :]
+            attention_mask_shape = shape_list(extended_attention_mask)
+            extended_attention_mask = tf.reshape(
+                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+            )
+            if past_key_values[0] is not None:
+                # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
+                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
+        else:
+            extended_attention_mask = tf.reshape(
+                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+        if self.is_decoder and encoder_attention_mask is not None:
+            # If a 2D or 3D attention mask is provided for the cross-attention,
+            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+            if num_dims_encoder_attention_mask == 3:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+            if num_dims_encoder_attention_mask == 2:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.LayerNorm(inputs=sequence_output) + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaPreTrainedModel with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm +class TFRobertaPreLayerNormPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RobertaPreLayerNormConfig + base_model_prefix = "roberta_prelayernorm" + + +ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. 
+ + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm +class TFRobertaPreLayerNormModel(TFRobertaPreLayerNormPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer(config, name="roberta_prelayernorm") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. 
This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.roberta_prelayernorm( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->RobertaPreLayerNorm +class TFRobertaPreLayerNormLMHead(keras.layers.Layer): + """RobertaPreLayerNorm Head for masked language modeling.""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
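+        # Weight tying: `call` below multiplies by `self.decoder.weight` (the
+        # input embedding matrix) with `transpose_b=True`, so the only new
+        # output-side variable is the per-token `bias` created in `build`.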
+ self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings( + """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING +) +class TFRobertaPreLayerNormForMaskedLM(TFRobertaPreLayerNormPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( + config, add_pooling_layer=False, name="roberta_prelayernorm" + ) + self.lm_head = TFRobertaPreLayerNormLMHead(config, self.roberta_prelayernorm.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + expected_output="' Paris'", + expected_loss=0.69, + ) + # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM.call with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm +class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFCausalLanguageModelingLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + def __init__(self, config: RobertaPreLayerNormConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning( + "If you want to use `TFRobertaPreLayerNormLMHeadModel` as a standalone, add `is_decoder=True.`" + ) + + self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( + config, add_pooling_layer=False, name="roberta_prelayernorm" + ) + self.lm_head = TFRobertaPreLayerNormLMHead( + config, input_embeddings=self.roberta_prelayernorm.embeddings, name="lm_head" + ) + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + outputs = self.roberta_prelayernorm( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + logits = self.lm_head(hidden_states=sequence_output, training=training) + loss = None + + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm +class TFRobertaPreLayerNormClassificationHead(keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.out_proj = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" + ) + self.config = config + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. 
to [CLS]) + x = self.dropout(x, training=training) + x = self.dense(x) + x = self.dropout(x, training=training) + x = self.out_proj(x) + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa-PreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top + of the pooled output) e.g. for GLUE tasks. + """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +class TFRobertaPreLayerNormForSequenceClassification( + TFRobertaPreLayerNormPreTrainedModel, TFSequenceClassificationLoss +): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( + config, add_pooling_layer=False, name="roberta_prelayernorm" + ) + self.classifier = TFRobertaPreLayerNormClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification.call with roberta->roberta_prelayernorm + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled + output and a softmax) e.g. for RocStories/SWAG tasks. + """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm +class TFRobertaPreLayerNormForMultipleChoice(TFRobertaPreLayerNormPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer(config, name="roberta_prelayernorm") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward( + ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + outputs = self.roberta_prelayernorm( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa-PreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + ROBERTA_PRELAYERNORM_START_DOCSTRING, +) +class TFRobertaPreLayerNormForTokenClassification(TFRobertaPreLayerNormPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer( + config, add_pooling_layer=False, name="roberta_prelayernorm" + ) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification.call with roberta->roberta_prelayernorm + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoBERTa-PreLayerNorm Model with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+
+    """,
+    ROBERTA_PRELAYERNORM_START_DOCSTRING,
+)
+class TFRobertaPreLayerNormForQuestionAnswering(TFRobertaPreLayerNormPreTrainedModel, TFQuestionAnsweringLoss):
+    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.roberta_prelayernorm = TFRobertaPreLayerNormMainLayer(
+            config, add_pooling_layer=False, name="roberta_prelayernorm"
+        )
+        self.qa_outputs = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering.call with roberta->roberta_prelayernorm
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        start_positions: np.ndarray | tf.Tensor | None = None,
+        end_positions: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
+        r"""
+        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for the position (index) of the start of the labelled span, used to compute the token classification
+            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
+            sequence are not taken into account for computing the loss.
+        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for the position (index) of the end of the labelled span, used to compute the token classification
+            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
+            sequence are not taken into account for computing the loss.
+ """ + outputs = self.roberta_prelayernorm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roberta_prelayernorm", None) is not None: + with tf.name_scope(self.roberta_prelayernorm.name): + self.roberta_prelayernorm.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFRobertaPreLayerNormForCausalLM", + "TFRobertaPreLayerNormForMaskedLM", + "TFRobertaPreLayerNormForMultipleChoice", + "TFRobertaPreLayerNormForQuestionAnswering", + "TFRobertaPreLayerNormForSequenceClassification", + "TFRobertaPreLayerNormForTokenClassification", + "TFRobertaPreLayerNormMainLayer", + "TFRobertaPreLayerNormModel", + "TFRobertaPreLayerNormPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roc_bert/__init__.py b/docs/transformers/src/transformers/models/roc_bert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb847b03b75587df50ebb17140401da0be4dea22 --- /dev/null +++ b/docs/transformers/src/transformers/models/roc_bert/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
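
The package `__init__` below defers its real imports: at runtime the module object is swapped for a `_LazyModule`, so submodules such as `modeling_roc_bert` are only imported when one of their attributes is first accessed. As a rough sketch of the underlying idea (a simplified PEP 562 `__getattr__` hook with a hypothetical attribute map, not the library's actual `_LazyModule` implementation):

```python
# Simplified sketch of lazy attribute loading for a package __init__ (PEP 562).
# `_ATTR_TO_SUBMODULE` is a hypothetical map, not the real import structure.
import importlib

_ATTR_TO_SUBMODULE = {
    "RoCBertConfig": ".configuration_roc_bert",
    "RoCBertModel": ".modeling_roc_bert",
}


def __getattr__(name):
    # Only called when `name` is not already defined on the module, so the
    # expensive submodule import is deferred until first use.
    if name in _ATTR_TO_SUBMODULE:
        module = importlib.import_module(_ATTR_TO_SUBMODULE[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

This keeps a bare `import transformers` cheap while `from transformers import RoCBertModel` still resolves on demand.
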
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_roc_bert import * + from .modeling_roc_bert import * + from .tokenization_roc_bert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/roc_bert/configuration_roc_bert.py b/docs/transformers/src/transformers/models/roc_bert/configuration_roc_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf53bf3384268523e518128e2ee4945d6715328 --- /dev/null +++ b/docs/transformers/src/transformers/models/roc_bert/configuration_roc_bert.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2022 WeChatAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RoCBert model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class RoCBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RoCBertModel`]. It is used to instantiate a + RoCBert model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RoCBert + [weiweishi/roc-bert-base-zh](https://huggingface.co/weiweishi/roc-bert-base-zh) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the RoCBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RoCBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. 
+ max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RoCBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + enable_pronunciation (`bool`, *optional*, defaults to `True`): + Whether or not the model use pronunciation embed when training. + enable_shape (`bool`, *optional*, defaults to `True`): + Whether or not the model use shape embed when training. + pronunciation_embed_dim (`int`, *optional*, defaults to 768): + Dimension of the pronunciation_embed. + pronunciation_vocab_size (`int`, *optional*, defaults to 910): + Pronunciation Vocabulary size of the RoCBert model. Defines the number of different tokens that can be + represented by the `input_pronunciation_ids` passed when calling [`RoCBertModel`]. + shape_embed_dim (`int`, *optional*, defaults to 512): + Dimension of the shape_embed. + shape_vocab_size (`int`, *optional*, defaults to 24858): + Shape Vocabulary size of the RoCBert model. Defines the number of different tokens that can be represented + by the `input_shape_ids` passed when calling [`RoCBertModel`]. 
+ concat_input (`bool`, *optional*, defaults to `True`): + Defines the way of merging the shape_embed, pronunciation_embed and word_embed, if the value is true, + output_embed = torch.cat((word_embed, shape_embed, pronunciation_embed), -1), else output_embed = + (word_embed + shape_embed + pronunciation_embed) / 3 + Example: + + ```python + >>> from transformers import RoCBertModel, RoCBertConfig + + >>> # Initializing a RoCBert weiweishi/roc-bert-base-zh style configuration + >>> configuration = RoCBertConfig() + + >>> # Initializing a model from the weiweishi/roc-bert-base-zh style configuration + >>> model = RoCBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "roc_bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_cache=True, + pad_token_id=0, + position_embedding_type="absolute", + classifier_dropout=None, + enable_pronunciation=True, + enable_shape=True, + pronunciation_embed_dim=768, + pronunciation_vocab_size=910, + shape_embed_dim=512, + shape_vocab_size=24858, + concat_input=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.enable_pronunciation = enable_pronunciation + self.enable_shape = enable_shape + self.pronunciation_embed_dim = pronunciation_embed_dim + self.pronunciation_vocab_size = pronunciation_vocab_size + self.shape_embed_dim = shape_embed_dim + self.shape_vocab_size = shape_vocab_size + self.concat_input = concat_input + self.position_embedding_type = position_embedding_type + self.classifier_dropout = classifier_dropout + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +__all__ = ["RoCBertConfig"] diff --git a/docs/transformers/src/transformers/models/roc_bert/modeling_roc_bert.py b/docs/transformers/src/transformers/models/roc_bert/modeling_roc_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ca264fb73d3e28b576aa519a0c0d330baed152 --- /dev/null +++ b/docs/transformers/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -0,0 +1,2017 @@ +# coding=utf-8 +# Copyright 2022 WeChatAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
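
Before the modeling code, a note on the `concat_input` merge described in the configuration above: the word, shape and pronunciation embeddings are either concatenated and projected back to `hidden_size`, or summed and averaged. A self-contained toy sketch (illustrative shapes; the `Linear` stands in for the model's `map_inputs_layer`):

```python
# Toy illustration of the two embedding-merge modes selected by `concat_input`.
import torch

batch_size, seq_len = 2, 5
word_embed = torch.randn(batch_size, seq_len, 768)           # hidden_size
shape_embed = torch.randn(batch_size, seq_len, 512)          # shape_embed_dim
pronunciation_embed = torch.randn(batch_size, seq_len, 768)  # pronunciation_embed_dim

# concat_input=True: concatenate on the feature axis, then project back to
# hidden_size (the model does this with its `map_inputs_layer` Linear).
merged = torch.cat((word_embed, shape_embed, pronunciation_embed), -1)
project = torch.nn.Linear(768 + 512 + 768, 768)
output_embed = project(merged)  # (2, 5, 768)

# concat_input=False: the embeddings are summed and averaged, which requires
# the shape/pronunciation embeddings to share hidden_size (768 here).
shape_embed_768 = torch.randn(batch_size, seq_len, 768)
output_embed_avg = (word_embed + shape_embed_768 + pronunciation_embed) / 3
print(output_embed.shape, output_embed_avg.shape)  # both torch.Size([2, 5, 768])
```
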
+
+"""PyTorch RoCBert model."""
+
+import math
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_roc_bert import RoCBertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "weiweishi/roc-bert-base-zh"
+_CONFIG_FOR_DOC = "RoCBertConfig"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
+
+# Token Classification output
+_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "ArthurZ/dummy-rocbert-ner"
+_TOKEN_CLASS_EXPECTED_OUTPUT = ["S-EVENT", "S-FAC", "I-ORDINAL", "I-ORDINAL", "E-ORG", "E-LANGUAGE", "E-ORG", "E-ORG", "E-ORG", "E-ORG", "I-EVENT", "S-TIME", "S-TIME", "E-LANGUAGE", "S-TIME", "E-DATE", "I-ORDINAL", "E-QUANTITY", "E-LANGUAGE", "S-TIME", "B-ORDINAL", "S-PRODUCT", "E-LANGUAGE", "E-LANGUAGE", "E-ORG", "E-LOC", "S-TIME", "I-ORDINAL", "S-FAC", "O", "S-GPE", "I-EVENT", "S-GPE", "E-LANGUAGE", "E-ORG", "S-EVENT", "S-FAC", "S-FAC", "S-FAC", "E-ORG", "S-FAC", "E-ORG", "S-GPE"]  # fmt: skip
+_TOKEN_CLASS_EXPECTED_LOSS = 3.62
+
+# SequenceClassification docstring
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/dummy-rocbert-seq"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'financial news'"
+_SEQ_CLASS_EXPECTED_LOSS = 2.31
+
+# QuestionAnswering docstring
+_CHECKPOINT_FOR_QA = "ArthurZ/dummy-rocbert-qa"
+_QA_EXPECTED_OUTPUT = "''"
+_QA_EXPECTED_LOSS = 3.75
+_QA_TARGET_START_INDEX = 14
+_QA_TARGET_END_INDEX = 15
+
+# Masked language modeling
+
+
+# Copied from transformers.models.bert.modeling_bert.load_tf_weights_in_bert with bert->roc_bert
+def load_tf_weights_in_roc_bert(model, config, tf_checkpoint_path):
+    """Load TF checkpoints in a PyTorch model."""
+    try:
+        import re
+
+        import numpy as np
+        import tensorflow as tf
+    except ImportError:
+        logger.error(
+            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
+            "https://www.tensorflow.org/install/ for installation instructions."
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class RoCBertEmbeddings(nn.Module): + """Construct the embeddings from word, position, shape, pronunciation and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.pronunciation_embed = nn.Embedding( + config.pronunciation_vocab_size, config.pronunciation_embed_dim, padding_idx=config.pad_token_id + ) + self.shape_embed = nn.Embedding( + config.shape_vocab_size, config.shape_embed_dim, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.enable_pronunciation = config.enable_pronunciation + self.enable_shape = config.enable_shape + + if config.concat_input: + input_dim = config.hidden_size + if self.enable_pronunciation: + pronunciation_dim = config.pronunciation_embed_dim + input_dim += pronunciation_dim + if self.enable_shape: + shape_dim = config.shape_embed_dim + input_dim += shape_dim + self.map_inputs_layer = torch.nn.Linear(input_dim, config.hidden_size) + else: + self.map_inputs_layer = None + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len 
position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), + persistent=False, + ) + + def forward( + self, + input_ids=None, + input_shape_ids=None, + input_pronunciation_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if self.map_inputs_layer is None: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + denominator = 1 + embedding_in = torch.clone(embeddings) + if self.enable_shape and input_shape_ids is not None: + embedding_shape = self.shape_embed(input_shape_ids) + embedding_in += embedding_shape + denominator += 1 + if self.enable_pronunciation and input_pronunciation_ids is not None: + embedding_pronunciation = self.pronunciation_embed(input_pronunciation_ids) + embedding_in += embedding_pronunciation + denominator += 1 + + embedding_in /= denominator + return embedding_in + else: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) # embedding_word + device = inputs_embeds.device + + embedding_in = torch.clone(inputs_embeds) + if self.enable_shape: + if input_shape_ids is None: + input_shape_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + embedding_shape = self.shape_embed(input_shape_ids) + embedding_in = torch.cat((embedding_in, embedding_shape), -1) + if self.enable_pronunciation: + if input_pronunciation_ids is None: + input_pronunciation_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + embedding_pronunciation = self.pronunciation_embed(input_pronunciation_ids) + embedding_in = torch.cat((embedding_in, embedding_pronunciation), -1) + + embedding_in = self.map_inputs_layer(embedding_in) # batch_size * seq_len * hidden_dim + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embedding_in += token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) 
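+                # add absolute position information before the final LayerNorm and dropout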
+ embedding_in += position_embeddings + + embedding_in = self.LayerNorm(embedding_in) + embedding_in = self.dropout(embedding_in) + return embedding_in + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RoCBert +class RoCBertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
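+        # The four branches below cover: cached cross-attention, fresh cross-attention,
+        # cached decoder self-attention, and plain self-attention without a cache.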
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RoCBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
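+        # (softmax over the key dimension, dim=-1, so each query's weights sum to 1)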
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoCBert +class RoCBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ROC_BERT_SELF_ATTENTION_CLASSES = { + "eager": RoCBertSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT +class RoCBertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = RoCBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from 
transformers.models.bert.modeling_bert.BertIntermediate with Bert->RoCBert +class RoCBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoCBert +class RoCBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RoCBert +class RoCBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RoCBertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RoCBertAttention(config, position_embedding_type="absolute") + self.intermediate = RoCBertIntermediate(config) + self.output = RoCBertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + 
cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RoCBert +class RoCBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RoCBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->RoCBert +class RoCBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RoCBert +class RoCBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->RoCBert +class RoCBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = RoCBertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
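+        # `decoder.weight` is tied to the input word embeddings via `_tied_weights_keys` on the
+        # head models below; the bias stays a separate Parameter and is re-attached to the decoder
+        # in `_tie_weights` so that `resize_token_embeddings` resizes both together.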
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RoCBert +class RoCBertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = RoCBertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class RoCBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RoCBertConfig + load_tf_weights = load_tf_weights_in_roc_bert + base_model_prefix = "roc_bert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, RoCBertLMPredictionHead): + module.bias.data.zero_() + + +ROC_BERT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RoCBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROC_BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + input_shape_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the shape vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input_shape_ids) + input_pronunciation_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the pronunciation vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+
+            [What are input IDs?](../glossary#input_pronunciation_ids)
+        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare RoCBert Model transformer outputting raw hidden-states without any specific head on top.",
+    ROC_BERT_START_DOCSTRING,
+)
+class RoCBertModel(RoCBertPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` and
+    `add_cross_attention` arguments set to `True`; an `encoder_hidden_states` is then expected as an input to the
+    forward pass.
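+
+    A minimal sketch of configuring the model as a decoder (assuming the `weiweishi/roc-bert-base-zh` checkpoint used
+    in the examples elsewhere in this file):
+
+    ```python
+    >>> from transformers import RoCBertConfig, RoCBertModel
+
+    >>> config = RoCBertConfig.from_pretrained("weiweishi/roc-bert-base-zh")
+    >>> config.is_decoder = True
+    >>> model = RoCBertModel.from_pretrained("weiweishi/roc-bert-base-zh", config=config)
+    ```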
+ """ + + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->RoCBert + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RoCBertEmbeddings(config) + self.encoder = RoCBertEncoder(config) + + self.pooler = RoCBertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def get_pronunciation_embeddings(self): + return self.embeddings.pronunciation_embed + + def set_pronunciation_embeddings(self, value): + self.embeddings.pronunciation_embed = value + + def get_shape_embeddings(self): + return self.embeddings.shape_embed + + def set_shape_embeddings(self, value): + self.embeddings.shape_embed = value + + # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
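+        # `get_extended_attention_mask` broadcasts the [batch_size, seq_length] mask to
+        # [batch_size, 1, 1, seq_length] (also adding a causal triangle when `config.is_decoder=True`)
+        # and turns 0s into a large negative additive bias so masked positions vanish after softmax.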
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + input_shape_ids=input_shape_ids, + input_pronunciation_ids=input_pronunciation_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + RoCBert Model with contrastive loss and masked_lm_loss during the pretraining. 
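+
+    When `attack_*` and `labels_*` inputs are provided, the loss adds two in-batch contrastive terms to the masked
+    language modeling loss, aligning the pooled representation of each sequence (and of its target sequence) with
+    that of its adversarial "attack" counterpart.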
+ """, + ROC_BERT_START_DOCSTRING, +) +class RoCBertForPreTraining(RoCBertPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.roc_bert = RoCBertModel(config) + self.cls = RoCBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + attack_input_ids: Optional[torch.Tensor] = None, + attack_input_shape_ids: Optional[torch.Tensor] = None, + attack_input_pronunciation_ids: Optional[torch.Tensor] = None, + attack_attention_mask: Optional[torch.Tensor] = None, + attack_token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels_input_ids: Optional[torch.Tensor] = None, + labels_input_shape_ids: Optional[torch.Tensor] = None, + labels_input_pronunciation_ids: Optional[torch.Tensor] = None, + labels_attention_mask: Optional[torch.Tensor] = None, + labels_token_type_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + attack_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + attack sample ids for computing the contrastive loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + attack_input_shape_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + attack sample shape ids for computing the contrastive loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + attack_input_pronunciation_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + attack sample pronunciation ids for computing the contrastive loss. Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + labels_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + target ids for computing the contrastive loss and masked_lm_loss . 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + labels_input_shape_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + target shape ids for computing the contrastive loss and masked_lm_loss . Indices should be in `[-100, + 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + labels_input_pronunciation_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + target pronunciation ids for computing the contrastive loss and masked_lm_loss . Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., + config.vocab_size]` + + kwargs (`Dict[str, any]`, *optional*, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RoCBertForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh") + >>> model = RoCBertForPreTraining.from_pretrained("weiweishi/roc-bert-base-zh") + + >>> inputs = tokenizer("你好,很高兴认识你", return_tensors="pt") + >>> attack_inputs = {} + >>> for key in list(inputs.keys()): + ... attack_inputs[f"attack_{key}"] = inputs[key] + >>> label_inputs = {} + >>> for key in list(inputs.keys()): + ... label_inputs[f"labels_{key}"] = inputs[key] + + >>> inputs.update(label_inputs) + >>> inputs.update(attack_inputs) + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> logits.shape + torch.Size([1, 11, 21128]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roc_bert( + input_ids, + input_shape_ids=input_shape_ids, + input_pronunciation_ids=input_pronunciation_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores = self.cls(sequence_output) + + loss = None + if labels_input_ids is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels_input_ids.view(-1)) + + if attack_input_ids is not None: + batch_size, _ = labels_input_ids.shape + device = labels_input_ids.device + + target_inputs = torch.clone(labels_input_ids) + target_inputs[target_inputs == -100] = self.config.pad_token_id + + labels_output = self.roc_bert( + target_inputs, + input_shape_ids=labels_input_shape_ids, + input_pronunciation_ids=labels_input_pronunciation_ids, + attention_mask=labels_attention_mask, + token_type_ids=labels_token_type_ids, + return_dict=return_dict, + ) + attack_output = self.roc_bert( + attack_input_ids, + input_shape_ids=attack_input_shape_ids, + input_pronunciation_ids=attack_input_pronunciation_ids, + attention_mask=attack_attention_mask, + token_type_ids=attack_token_type_ids, + return_dict=return_dict, + ) + + labels_pooled_output = labels_output[1] + 
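+                # The pooled ([CLS]) outputs are L2-normalized below, so each matmul yields a
+                # batch_size x batch_size cosine-similarity matrix; scaled by 100, these act as
+                # logits for a cross-entropy whose targets are the diagonal (in-batch) indices.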
+                attack_pooled_output = attack_output[1]
+
+                pooled_output_norm = torch.nn.functional.normalize(pooled_output, dim=-1)
+                labels_pooled_output_norm = torch.nn.functional.normalize(labels_pooled_output, dim=-1)
+                attack_pooled_output_norm = torch.nn.functional.normalize(attack_pooled_output, dim=-1)
+
+                sim_matrix = torch.matmul(pooled_output_norm, attack_pooled_output_norm.T)  # batch_size x batch_size
+                sim_matrix_target = torch.matmul(labels_pooled_output_norm, attack_pooled_output_norm.T)
+                batch_labels = torch.tensor(list(range(batch_size)), device=device)
+                contrastive_loss = (
+                    loss_fct(100 * sim_matrix.view(batch_size, -1), batch_labels.view(-1))
+                    + loss_fct(100 * sim_matrix_target.view(batch_size, -1), batch_labels.view(-1))
+                ) / 2
+
+                loss = contrastive_loss + masked_lm_loss
+            else:
+                loss = masked_lm_loss
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MaskedLMOutput(
+            loss=loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings("""RoCBert Model with a `language modeling` head on top.""", ROC_BERT_START_DOCSTRING)
+class RoCBertForMaskedLM(RoCBertPreTrainedModel):
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `RoCBertForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+            )
+
+        self.roc_bert = RoCBertModel(config, add_pooling_layer=False)
+        self.cls = RoCBertOnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.get_output_embeddings
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        input_shape_ids: Optional[torch.Tensor] = None,
+        input_pronunciation_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, RoCBertForMaskedLM
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
+        >>> model = RoCBertForMaskedLM.from_pretrained("weiweishi/roc-bert-base-zh")
+
+        >>> inputs = tokenizer("法国是首都[MASK].", return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> # retrieve index of [MASK]
+        >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+
+        >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+        >>> tokenizer.decode(predicted_token_id)
+        '.'
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.roc_bert(
+            input_ids,
+            input_shape_ids=input_shape_ids,
+            input_pronunciation_ids=input_pronunciation_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, input_shape_ids=None, input_pronunciation_ids=None, attention_mask=None, **model_kwargs
+    ):
+        input_shape = input_ids.shape
+        effective_batch_size = input_shape[0]
+
+        # add a dummy token
+        if self.config.pad_token_id is None:
+            raise ValueError("The PAD token should be defined for generation")
+
+        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+        dummy_token = torch.full(
+            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+        )
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+        if input_shape_ids is not None:
+            input_shape_ids = torch.cat([input_shape_ids, dummy_token], dim=1)
+        if input_pronunciation_ids is not None:
+            input_pronunciation_ids = torch.cat([input_pronunciation_ids, dummy_token], dim=1)
+
+        return {
+            "input_ids": input_ids,
+            "input_shape_ids": input_shape_ids,
+            "input_pronunciation_ids": input_pronunciation_ids,
+            "attention_mask": attention_mask,
+        }
+
+
+@add_start_docstrings(
+    """RoCBert Model with a `language modeling` head on top for CLM fine-tuning.""", ROC_BERT_START_DOCSTRING
+)
+class RoCBertForCausalLM(RoCBertPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `RoCBertForCausalLM` as a standalone, add `is_decoder=True`.")
+
+        self.roc_bert =
RoCBertModel(config, add_pooling_layer=False) + self.cls = RoCBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are + only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, RoCBertForCausalLM, RoCBertConfig
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
+        >>> config = RoCBertConfig.from_pretrained("weiweishi/roc-bert-base-zh")
+        >>> config.is_decoder = True
+        >>> model = RoCBertForCausalLM.from_pretrained("weiweishi/roc-bert-base-zh", config=config)
+
+        >>> inputs = tokenizer("你好,很高兴认识你", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.roc_bert(
+            input_ids,
+            input_shape_ids=input_shape_ids,
+            input_pronunciation_ids=input_pronunciation_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        input_shape_ids=None,
+        input_pronunciation_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        **model_kwargs,
+    ):
+        # Overwritten -- `input_pronunciation_ids`
+
+        input_shape = input_ids.shape
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+            if input_shape_ids is not None:
+                input_shape_ids = input_shape_ids[:, -1:]
+            if input_pronunciation_ids is not None:
+                input_pronunciation_ids = input_pronunciation_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "input_shape_ids": input_shape_ids,
+            "input_pronunciation_ids": input_pronunciation_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+        }
+
+    # Copied from
transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """RoCBert Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks.""", + ROC_BERT_START_DOCSTRING, +) +class RoCBertForSequenceClassification(RoCBertPreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with Bert->RoCBert,bert->roc_bert + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roc_bert = RoCBertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
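+            When `config.problem_type` is unset, it is inferred at the first forward call: regression for
+            `num_labels == 1`, single-label classification for integer `labels`, and multi-label classification
+            otherwise.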
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roc_bert( + input_ids, + input_shape_ids=input_shape_ids, + input_pronunciation_ids=input_pronunciation_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """RoCBert Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. 
for RocStories/SWAG tasks.""", + ROC_BERT_START_DOCSTRING, +) +class RoCBertForMultipleChoice(RoCBertPreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.__init__ with Bert->RoCBert,bert->roc_bert + def __init__(self, config): + super().__init__(config) + + self.roc_bert = RoCBertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + ROC_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + input_shape_ids = input_shape_ids.view(-1, input_shape_ids.size(-1)) if input_shape_ids is not None else None + input_pronunciation_ids = ( + input_pronunciation_ids.view(-1, input_pronunciation_ids.size(-1)) + if input_pronunciation_ids is not None + else None + ) + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.roc_bert( + input_ids, + input_shape_ids=input_shape_ids, + input_pronunciation_ids=input_pronunciation_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """RoCBert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. 
for Named-Entity-Recognition (NER) tasks.""", + ROC_BERT_START_DOCSTRING, +) +class RoCBertForTokenClassification(RoCBertPreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->RoCBert,bert->roc_bert + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roc_bert = RoCBertModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, + expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
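+            Positions with the label `-100` are ignored by the cross-entropy loss (its default `ignore_index`),
+            which is the usual way to exclude padding and special tokens from the loss.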
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roc_bert( + input_ids, + input_shape_ids=input_shape_ids, + input_pronunciation_ids=input_pronunciation_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """RoCBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""", + ROC_BERT_START_DOCSTRING, +) +class RoCBertForQuestionAnswering(RoCBertPreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->RoCBert,bert->roc_bert + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roc_bert = RoCBertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_QA, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + input_shape_ids: Optional[torch.Tensor] = None, + input_pronunciation_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roc_bert( + input_ids, + input_shape_ids=input_shape_ids, + input_pronunciation_ids=input_pronunciation_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "RoCBertForCausalLM", + "RoCBertForMaskedLM", + "RoCBertForMultipleChoice", + "RoCBertForPreTraining", + "RoCBertForQuestionAnswering", + "RoCBertForSequenceClassification", + "RoCBertForTokenClassification", + "RoCBertLayer", + "RoCBertModel", + "RoCBertPreTrainedModel", + "load_tf_weights_in_roc_bert", +] diff --git a/docs/transformers/src/transformers/models/roc_bert/tokenization_roc_bert.py b/docs/transformers/src/transformers/models/roc_bert/tokenization_roc_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f38a48e51afc21f621fad4b85dfc1de75eb57d --- /dev/null +++ b/docs/transformers/src/transformers/models/roc_bert/tokenization_roc_bert.py @@ -0,0 +1,1122 @@ +# coding=utf-8 +# Copyright 2022 WeChatAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes for RoCBert.""" + +import collections +import itertools +import json +import os +import unicodedata +from typing import Dict, List, Optional, Tuple, Union + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, + BatchEncoding, + EncodedInput, + EncodedInputPair, + PaddingStrategy, + PreTokenizedInput, + PreTokenizedInputPair, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...utils import add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.txt", + "word_shape_file": "word_shape.json", + "word_pronunciation_file": "word_pronunciation.json", +} + + +# Copied from transformers.models.bert.tokenization_bert.load_vocab +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class RoCBertTokenizer(PreTrainedTokenizer): + r""" + Args: + Construct a RoCBert tokenizer. Based on WordPiece. This tokenizer inherits from [`PreTrainedTokenizer`] which + contains most of the main methods. Users should refer to this superclass for more information regarding those + methods. + vocab_file (`str`): + File containing the vocabulary. + word_shape_file (`str`): + File containing the word => shape info. + word_pronunciation_file (`str`): + File containing the word => pronunciation info. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + word_shape_file, + word_pronunciation_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + for cur_file in [vocab_file, word_shape_file, word_pronunciation_file]: + if cur_file is None or not os.path.isfile(cur_file): + raise ValueError( + f"Can't find a vocabulary file at path '{cur_file}'. To load the vocabulary from a Google " + "pretrained model use `tokenizer = RoCBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + + self.vocab = load_vocab(vocab_file) + + with open(word_shape_file, "r", encoding="utf8") as in_file: + self.word_shape = json.load(in_file) + + with open(word_pronunciation_file, "r", encoding="utf8") as in_file: + self.word_pronunciation = json.load(in_file) + + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = RoCBertBasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = RoCBertWordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words:
bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + tokens_ids = self.convert_tokens_to_ids(tokens) + tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens) + tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens) + return tokens_ids, tokens_shape_ids, tokens_proun_ids + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_split_into_words: + tokens = list( + itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) + ) + tokens_ids = self.convert_tokens_to_ids(tokens) + tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens) + tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens) + return tokens_ids, tokens_shape_ids, tokens_proun_ids + else: + tokens_ids = self.convert_tokens_to_ids(text) + tokens_shape_ids = self.convert_tokens_to_shape_ids(text) + tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(text) + return tokens_ids, tokens_shape_ids, tokens_proun_ids + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text, [0] * len(text), [0] * len(text) # shape and pronunciation ids are set to the pad value + else: + if is_split_into_words: + raise ValueError( + f"Input {text} is not valid. Should be a string or a list/tuple of strings when" + " `is_split_into_words=True`." + ) + else: + raise ValueError( + f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" + " integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + first_ids, first_shape_ids, first_proun_ids = get_input_ids(text) + if text_pair is not None: + second_ids, second_shape_ids, second_proun_ids = get_input_ids(text_pair) + else: + second_ids, second_shape_ids, second_proun_ids = None, None, None + + return self.prepare_for_model( + first_ids, + first_shape_ids, + first_proun_ids, + pair_ids=second_ids, + pair_shape_ids=second_shape_ids, + pair_pronunciation_ids=second_proun_ids, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: List[int], + shape_ids: List[int], + pronunciation_ids: List[int], + pair_ids: Optional[List[int]] = None, + pair_shape_ids: Optional[List[int]] = None, + pair_pronunciation_ids: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user-defined stride) for overflowing tokens. Note that for *pair_ids* + different from `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return + overflowing tokens. Such a combination of arguments will raise an error. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + shape_ids (`List[int]`): + Tokenized shape ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_shape_ids` methods. + pronunciation_ids (`List[int]`): + Tokenized pronunciation ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_pronunciation_ids` methods. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + pair_shape_ids (`List[int]`, *optional*): + Tokenized shape ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_shape_ids` methods.
+ pair_pronunciation_ids (`List[int]`, *optional*): + Tokenized pronunciation ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_pronunciation_ids` methods. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "It is not possible to return overflowing tokens for a pair of sequences with the " + "`longest_first` truncation strategy. Please select a strategy other than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + shape_ids, pair_shape_ids, _ = self.truncate_sequences( + shape_ids, + pair_ids=pair_shape_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + pronunciation_ids, pair_pronunciation_ids, _ = self.truncate_sequences( + pronunciation_ids, + pair_ids=pair_pronunciation_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + input_shape_ids = self.build_inputs_with_special_tokens( + shape_ids, pair_shape_ids, self.word_shape["[UNK]"], self.word_shape["[UNK]"] + ) + input_pronunciation_ids = self.build_inputs_with_special_tokens( + pronunciation_ids, + pair_pronunciation_ids, + self.word_pronunciation["[UNK]"], + self.word_pronunciation["[UNK]"], + ) + else: + sequence = ids + pair_ids if pair_ids else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair_ids else []) + input_shape_ids = shape_ids + pair_shape_ids if pair_shape_ids else shape_ids + input_pronunciation_ids = ( + pronunciation_ids + pair_pronunciation_ids if pair_pronunciation_ids else
pronunciation_ids + ) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + encoded_inputs["input_shape_ids"] = input_shape_ids + encoded_inputs["input_pronunciation_ids"] = input_pronunciation_ids + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. 
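+ # Besides input_ids, RoCBert carries two parallel channels (input_shape_ids and + # input_pronunciation_ids); both the right- and left-padding branches below pad all + # three with pad_token_id so the channels stay index-aligned with the attention mask.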
+ if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + for key in ["input_shape_ids", "input_pronunciation_ids"]: + if key in encoded_inputs: + encoded_inputs[key] = encoded_inputs[key] + [self.pad_token_id] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + for key in ["input_shape_ids", "input_pronunciation_ids"]: + if key in encoded_inputs: + encoded_inputs[key] = [self.pad_token_id] * difference + encoded_inputs[key] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding side: " + str(padding_side)) + + return encoded_inputs + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + List[PreTokenizedInputPair], + List[EncodedInput], + List[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + tokens_ids = self.convert_tokens_to_ids(tokens) + tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens) + tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens) + return tokens_ids, tokens_shape_ids, tokens_proun_ids + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_split_into_words: + tokens = list( + itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) + ) + tokens_ids = self.convert_tokens_to_ids(tokens) + tokens_shape_ids = self.convert_tokens_to_shape_ids(tokens) + tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(tokens) + return tokens_ids, tokens_shape_ids,
tokens_proun_ids + else: + tokens_ids = self.convert_tokens_to_ids(text) + tokens_shape_ids = self.convert_tokens_to_shape_ids(text) + tokens_proun_ids = self.convert_tokens_to_pronunciation_ids(text) + return tokens_ids, tokens_shape_ids, tokens_proun_ids + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text, [0] * len(text), [0] * len(text) # shape and pronunciation ids are set to the pad value + else: + raise ValueError( + "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + input_ids = [] + input_shape_ids = [] + input_pronunciation_ids = [] + for ids_or_pair_ids in batch_text_or_text_pairs: + if not isinstance(ids_or_pair_ids, (list, tuple)): + ids, pair_ids = ids_or_pair_ids, None + elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)): + ids, pair_ids = ids_or_pair_ids, None + else: + ids, pair_ids = ids_or_pair_ids + + first_ids, first_shape_ids, first_proun_ids = get_input_ids(ids) + if pair_ids is not None: + second_ids, second_shape_ids, second_proun_ids = get_input_ids(pair_ids) + else: + second_ids, second_shape_ids, second_proun_ids = None, None, None + + input_ids.append((first_ids, second_ids)) + input_shape_ids.append((first_shape_ids, second_shape_ids)) + input_pronunciation_ids.append((first_proun_ids, second_proun_ids)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + batch_shape_ids_pairs=input_shape_ids, + batch_pronunciation_ids_pairs=input_pronunciation_ids, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], + batch_shape_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], + batch_pronunciation_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user-defined stride) for overflowing tokens. + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + batch_shape_ids_pairs: list of tokenized input shape ids or input shape ids pairs + batch_pronunciation_ids_pairs: list of tokenized input pronunciation ids or input pronunciation ids pairs + """ + + batch_outputs = {} + for i, (first_ids, second_ids) in enumerate(batch_ids_pairs): + first_shape_ids, second_shape_ids = batch_shape_ids_pairs[i] + first_pronunciation_ids, second_pronunciation_ids = batch_pronunciation_ids_pairs[i] + outputs = self.prepare_for_model( + first_ids, + first_shape_ids, + first_pronunciation_ids, + pair_ids=second_ids, + pair_shape_ids=second_shape_ids, + pair_pronunciation_ids=second_pronunciation_ids, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) into an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_token_to_shape_id(self, token): + """Converts a token (str) into a shape_id using the shape vocab.""" + return self.word_shape.get(token, self.word_shape.get(self.unk_token)) + + def convert_tokens_to_shape_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: + if tokens is None: + return None + + ids = [] + for token in tokens: + ids.append(self._convert_token_to_shape_id(token)) + return ids + + def _convert_token_to_pronunciation_id(self, token): + """Converts a token (str) into a pronunciation_id using the pronunciation vocab.""" + return self.word_pronunciation.get(token, self.word_pronunciation.get(self.unk_token)) + + def convert_tokens_to_pronunciation_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: + if tokens is None: + return None + + ids = [] + for token in tokens: + ids.append(self._convert_token_to_pronunciation_id(token)) + return ids + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) into a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string + def
convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) into a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + cls_token_id: Optional[int] = None, + sep_token_id: Optional[int] = None, + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + cls_token_id (`int`, *optional*): + ID to use in place of `self.cls_token_id` (e.g. the shape or pronunciation id of `[CLS]`). + sep_token_id (`int`, *optional*): + ID to use in place of `self.sep_token_id` (e.g. the shape or pronunciation id of `[SEP]`). + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + cls = [self.cls_token_id] if cls_token_id is None else [cls_token_id] + sep = [self.sep_token_id] if sep_token_id is None else [sep_token_id] + if token_ids_1 is None: + return cls + token_ids_0 + sep + return cls + token_ids_0 + sep + token_ids_1 + sep + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str, str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"], + ) + word_shape_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["word_shape_file"], + ) + word_pronunciation_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["word_pronunciation_file"], + ) + else: + raise ValueError( + f"Can't find a directory at path '{save_directory}'. To load the vocabulary from a Google " + "pretrained model use `tokenizer = RoCBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + + with open(word_shape_file, "w", encoding="utf8") as writer: + json.dump(self.word_shape, writer, ensure_ascii=False, indent=4, separators=(", ", ": ")) + + with open(word_pronunciation_file, "w", encoding="utf8") as writer: + json.dump(self.word_pronunciation, writer, ensure_ascii=False, indent=4, separators=(", ", ": ")) + + return ( + vocab_file, + word_shape_file, + word_pronunciation_file, + ) + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer with BasicTokenizer->RoCBertBasicTokenizer +class RoCBertBasicTokenizer: + """ + Constructs a RoCBertBasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. 
+ """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*): + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]). List of tokens not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + #
despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer +class RoCBertWordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["RoCBertTokenizer"] diff --git a/docs/transformers/src/transformers/models/roformer/__init__.py b/docs/transformers/src/transformers/models/roformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63c1c00e572391222d4f5b33d18401d0afd699ee --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
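Before the RoFormer files, a short aside: the greedy longest-match-first loop in `RoCBertWordpieceTokenizer.tokenize` above is the classic WordPiece algorithm. A stripped-down sketch of the same logic with a made-up toy vocabulary (illustrative only, not the library API):

```python
def wordpiece(token, vocab, unk="[UNK]"):
    # Greedy longest-match-first: take the longest prefix present in the vocab,
    # prefix continuations with "##", and fall back to [UNK] on any miss.
    pieces, start = [], 0
    while start < len(token):
        end = len(token)
        cur = None
        while start < end:
            sub = ("##" if start > 0 else "") + token[start:end]
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:
            return [unk]  # one unmatched span poisons the whole token
        pieces.append(cur)
        start = end
    return pieces


toy_vocab = {"un", "##aff", "##able"}  # made-up vocabulary for illustration
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
```

Because the outer loop always resumes at the first unmatched character, a single out-of-vocabulary span maps the entire token to `unk_token`, exactly what the `is_bad` flag does above.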
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_roformer import * + from .modeling_flax_roformer import * + from .modeling_roformer import * + from .modeling_tf_roformer import * + from .tokenization_roformer import * + from .tokenization_roformer_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/roformer/configuration_roformer.py b/docs/transformers/src/transformers/models/roformer/configuration_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1852509199e6d9e7967efa4f8741f50b05e70463 --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/configuration_roformer.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RoFormer model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class RoFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RoFormerModel`]. It is used to instantiate a + RoFormer model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RoFormer + [junnyu/roformer_chinese_base](https://huggingface.co/junnyu/roformer_chinese_base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50000): + Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`RoFormerModel`] or [`TFRoFormerModel`]. + embedding_size (`int`, *optional*, defaults to None): + Dimensionality of the token embeddings. Defaults to `hidden_size` if not provided. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 1536): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 1536). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RoFormerModel`] or [`TFRoFormerModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + rotary_value (`bool`, *optional*, defaults to `False`): + Whether or not to apply rotary position embeddings on the value layer. + + Example: + + ```python + >>> from transformers import RoFormerModel, RoFormerConfig + + >>> # Initializing a RoFormer junnyu/roformer_chinese_base style configuration + >>> configuration = RoFormerConfig() + + >>> # Initializing a model (with random weights) from the junnyu/roformer_chinese_base style configuration + >>> model = RoFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "roformer" + + def __init__( + self, + vocab_size=50000, + embedding_size=None, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1536, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + rotary_value=False, + use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.embedding_size = hidden_size if embedding_size is None else embedding_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.rotary_value = rotary_value + self.use_cache = use_cache + + +class RoFormerOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) + + +__all__ = ["RoFormerConfig", "RoFormerOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..d227948e0ee3ea8008f2893ac21afd846c79f036 --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RoFormer checkpoint.""" + +import argparse + +import torch + +from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = RoFormerConfig.from_json_file(bert_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = RoFormerForMaskedLM(config) + + # Load weights from tf checkpoint + load_tf_weights_in_roformer(model, config, tf_checkpoint_path) + + # Save the PyTorch model + print(f"Saving PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint." + ) + parser.add_argument( + "--bert_config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained RoFormer model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) diff --git a/docs/transformers/src/transformers/models/roformer/modeling_flax_roformer.py b/docs/transformers/src/transformers/models/roformer/modeling_flax_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f47146f16bbd7bbbf466dc6f59e09f93247a85ef --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/modeling_flax_roformer.py @@ -0,0 +1,1091 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax RoFormer model.""" + +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_roformer import RoFormerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base" +_CONFIG_FOR_DOC = "RoFormerConfig" + + +ROFORMER_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RoFormerConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +ROFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary.
+ + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions +def create_sinusoidal_positions(n_pos, dim): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + sentinel = dim // 2 + dim % 2 + out = np.zeros_like(position_enc) + out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) + out[:, sentinel:] = np.cos(position_enc[:, 1::2]) + + return jnp.array(out) + + +class FlaxRoFormerEmbeddings(nn.Module): + """Construct the embeddings from word and token_type embeddings.""" + + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxRoFormerSelfAttention(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of " + f"`config.num_attention_heads`: {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size,
dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.rotary_value = self.config.rotary_value + + def __call__( + self, + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + deterministic=True, + output_attentions: bool = False, + ): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + if sinusoidal_pos is not None: + if self.rotary_value: + query_states, key_states, value_states = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_states, key_states, value_states + ) + else: + query_states, key_states = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_states, key_states + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + @staticmethod + def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): + sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) + sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape) + cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape) + + def rotate_layer(layer, sin_pos, cos_pos): + rotate_half_layer = jnp.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape) + rotary_matrix_cos = jnp.einsum("bslh,...sh->bslh", layer, cos_pos) + rotary_matrix_sin = jnp.einsum("bslh,...sh->bslh", rotate_half_layer, sin_pos) + return rotary_matrix_cos + rotary_matrix_sin + + query_layer = rotate_layer(query_layer, sin_pos, cos_pos) + key_layer = rotate_layer(key_layer, sin_pos, cos_pos) + if value_layer is not None: + value_layer = rotate_layer(value_layer, sin_pos, cos_pos) + return query_layer, key_layer, value_layer + return 
query_layer, key_layer + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->RoFormer +class FlaxRoFormerSelfOutput(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FlaxRoFormerAttention(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxRoFormerSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxRoFormerSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask=layer_head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->RoFormer +class FlaxRoFormerIntermediate(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->RoFormer +class FlaxRoFormerOutput(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +class FlaxRoFormerLayer(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the 
computation
+
+    def setup(self):
+        self.attention = FlaxRoFormerAttention(self.config, dtype=self.dtype)
+        self.intermediate = FlaxRoFormerIntermediate(self.config, dtype=self.dtype)
+        self.output = FlaxRoFormerOutput(self.config, dtype=self.dtype)
+
+    def __call__(
+        self,
+        hidden_states,
+        attention_mask,
+        sinusoidal_pos,
+        layer_head_mask,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+    ):
+        attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            sinusoidal_pos,
+            layer_head_mask=layer_head_mask,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+        )
+        attention_output = attention_outputs[0]
+
+        hidden_states = self.intermediate(attention_output)
+        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attention_outputs[1],)
+        return outputs
+
+
+class FlaxRoFormerLayerCollection(nn.Module):
+    config: RoFormerConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.layers = [
+            FlaxRoFormerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+        ]
+
+    def __call__(
+        self,
+        hidden_states,
+        attention_mask,
+        sinusoidal_pos,
+        head_mask,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        # Check that head_mask provides a mask for each layer, if it is specified
+        if head_mask is not None:
+            if head_mask.shape[0] != len(self.layers):
+                raise ValueError(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.shape[0]}."
+ ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask=head_mask[i] if head_mask is not None else None, + deterministic=deterministic, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxRoFormerEncoder(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embed_positions = create_sinusoidal_positions( + self.config.max_position_embeddings, self.config.hidden_size // self.config.num_attention_heads + ) + self.layer = FlaxRoFormerLayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + sinusoidal_pos = self.embed_positions[: hidden_states.shape[1], :] + + return self.layer( + hidden_states, + attention_mask, + sinusoidal_pos, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->RoFormer +class FlaxRoFormerPredictionHeadTransform(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return self.LayerNorm(hidden_states) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->RoFormer +class FlaxRoFormerLMPredictionHead(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.transform = FlaxRoFormerPredictionHeadTransform(self.config, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.transform(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->RoFormer +class FlaxRoFormerOnlyMLMHead(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxRoFormerLMPredictionHead(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, 
shared_embedding=None): + hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) + return hidden_states + + +class FlaxRoFormerClassificationHead(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.out_proj = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RoFormerConfig + base_model_prefix = "roformer" + module_class: nn.Module = None + + def __init__( + self, + config: RoFormerConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + head_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) 
+ + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(head_mask, dtype="i4"), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxRoFormerModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embeddings = FlaxRoFormerEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRoFormerEncoder(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(input_ids, token_type_ids, attention_mask, deterministic=deterministic) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + if not return_dict: + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.", + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerModel(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerModule + + +append_call_sample_docstring(FlaxRoFormerModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) + + +class FlaxRoFormerForMaskedLMModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.cls = FlaxRoFormerOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roformer.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING) +class FlaxRoFormerForMaskedLM(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForMaskedLMModule + + 
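+# A minimal usage sketch (illustrative only; it assumes the `junnyu/roformer_chinese_base`
+# checkpoint, and the toy ids stand in for real tokenizer output):
+#
+#     import numpy as np
+#     from transformers import FlaxRoFormerForMaskedLM
+#
+#     model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
+#     input_ids = np.ones((1, 8), dtype="i4")
+#     logits = model(input_ids).logits  # (batch_size, seq_len, vocab_size)
+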
+append_call_sample_docstring( + FlaxRoFormerForMaskedLM, + _CHECKPOINT_FOR_DOC, + FlaxMaskedLMOutput, + _CONFIG_FOR_DOC, + mask="", +) + + +class FlaxRoFormerForSequenceClassificationModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.classifier = FlaxRoFormerClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForSequenceClassification(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxRoFormerForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRoFormerForMultipleChoiceModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) + + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Equivalent to sequence_summary call in the PyTorch implementation + hidden_states = outputs[0] + pooled_output = hidden_states[:, -1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForMultipleChoice(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxRoFormerForMultipleChoice, ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxRoFormerForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRoFormerForTokenClassificationModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForTokenClassification(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForTokenClassificationModule + + +append_call_sample_docstring( + FlaxRoFormerForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRoFormerForQuestionAnsweringModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForQuestionAnswering(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxRoFormerForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxRoFormerForMaskedLM", + "FlaxRoFormerForMultipleChoice", + "FlaxRoFormerForQuestionAnswering", + "FlaxRoFormerForSequenceClassification", + "FlaxRoFormerForTokenClassification", + "FlaxRoFormerModel", + "FlaxRoFormerPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roformer/modeling_roformer.py b/docs/transformers/src/transformers/models/roformer/modeling_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..86465153b10e980490939bf454103362418f379a --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/modeling_roformer.py @@ -0,0 +1,1660 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch RoFormer model.""" + +import math +import os +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, get_activation +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_roformer import RoFormerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base" +_CONFIG_FOR_DOC = "RoFormerConfig" + + +# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer +class RoFormerSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + + def _init_weight(self): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = self.weight.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + self.weight = nn.Parameter(out, requires_grad=False) + + @torch.no_grad() + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions) + + +def load_tf_weights_in_roformer(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+        )
+        raise
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        logger.info(f"Loading TF weight {name} with shape {shape}")
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name.replace("bert", "roformer"))
+        arrays.append(array)
+
+    for name, array in zip(names, arrays):
+        name = name.split("/")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+        # which are not required for using the pretrained model
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
+            logger.info(f"Skipping {'/'.join(name)}")
+            continue
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+                scope_names = re.split(r"_(\d+)", m_name)
+            else:
+                scope_names = [m_name]
+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "squad":
+                pointer = getattr(pointer, "classifier")
+            else:
+                try:
+                    pointer = getattr(pointer, scope_names[0])
+                except AttributeError:
+                    logger.info(f"Skipping {'/'.join(name)}")
+                    continue
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        if m_name[-11:] == "_embeddings":
+            pointer = getattr(pointer, "weight")
+        elif m_name == "kernel":
+            array = np.transpose(array)
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+        logger.info(f"Initialize PyTorch weight {name}")
+        pointer.data = torch.from_numpy(array)
+    return model
+
+
+class RoFormerEmbeddings(nn.Module):
+    """Construct the embeddings from word and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class RoFormerSelfAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden 
size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.rotary_value = config.rotary_value + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + sinusoidal_pos=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + query_layer = self.transpose_for_scores(mixed_query_layer) + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + if sinusoidal_pos is not None: + if self.rotary_value: + query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_layer, key_layer, value_layer + ) + else: + query_layer, key_layer = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_layer, key_layer + ) + if past_key_value is not None: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
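+        # Shape sketch (self-attention without cache, an illustrative note):
+        # query_layer and key_layer are (batch_size, num_heads, seq_len, head_size),
+        # so the matmul below yields (batch_size, num_heads, seq_len, seq_len),
+        # one raw score per query/key pair, scaled by 1/sqrt(head_size) right after.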
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the RoFormerModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    @staticmethod
+    def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
+        # https://kexue.fm/archives/8265
+        # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
+        # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
+        sin, cos = sinusoidal_pos.chunk(2, dim=-1)
+        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
+        sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
+        # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
+        cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
+        # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
+        rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
+            query_layer
+        )
+        query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
+        # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
+        rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
+        key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
+        if value_layer is not None:
+            # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
+            rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
+                value_layer
+            )
+            value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
+            return query_layer, key_layer, value_layer
+        return query_layer, key_layer
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoFormer
+class RoFormerSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class RoFormerAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.self = RoFormerSelfAttention(config)
+        self.output = 
RoFormerSelfOutput(config) + self.pruned_heads = set() + + # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + # End Copy + def forward( + self, + hidden_states, + attention_mask=None, + sinusoidal_pos=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + sinusoidal_pos, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RoFormer +class RoFormerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoFormer +class RoFormerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RoFormerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RoFormerAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RoFormerAttention(config) + self.intermediate = RoFormerIntermediate(config) + self.output = RoFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + sinusoidal_pos=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + 
output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + sinusoidal_pos, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention " + "layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + sinusoidal_pos, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class RoFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.embed_positions = RoFormerSinusoidalPositionalEmbedding( + config.max_position_embeddings, config.hidden_size // config.num_attention_heads + ) + self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] + sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :] + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->RoFormer +class RoFormerSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`RoFormerConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. 
+ - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: RoFormerConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
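+
+        Example (a minimal sketch; with a default `RoFormerConfig` the summary falls
+        back to `summary_type="last"` with no projection, activation or dropout):
+
+        ```python
+        >>> import torch
+
+        >>> summary = RoFormerSequenceSummary(RoFormerConfig())
+        >>> hidden_states = torch.randn(2, 7, 768)
+        >>> summary(hidden_states).shape
+        torch.Size([2, 768])
+        ```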
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +class RoFormerPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.embedding_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class RoFormerLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = RoFormerPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self) -> None: + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RoFormer +class RoFormerOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = RoFormerLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class RoFormerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = RoFormerConfig + load_tf_weights = load_tf_weights_in_roformer + base_model_prefix = "roformer" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, RoFormerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, RoFormerLMPredictionHead): + module.bias.data.zero_() + + +ROFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
+    ROFORMER_START_DOCSTRING,
+)
+class RoFormerModel(RoFormerPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+    `add_cross_attention` set to `True`; `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.embeddings = RoFormerEmbeddings(config)
+
+        if config.embedding_size != config.hidden_size:
+            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
+
+        self.encoder = RoFormerEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+        class `PreTrainedModel`.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
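+        # Illustrative sketch (comments only, not part of the original code): for a padded batch such as
+        #     attention_mask = torch.tensor([[1, 1, 1, 0]])   # (batch_size=1, seq_length=4)
+        # `get_extended_attention_mask` broadcasts the mask to shape (1, 1, 1, 4) and converts it to an
+        # additive bias, roughly (1.0 - mask) * torch.finfo(dtype).min, so that masked positions are
+        # pushed towards -inf before the softmax.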
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + if hasattr(self, "embeddings_project"): + embedding_output = self.embeddings_project(embedding_output) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=sequence_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING) +class RoFormerForMaskedLM(RoFormerPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `RoFormerForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.roformer = RoFormerModel(config) + self.cls = RoFormerOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING +) +class RoFormerForCausalLM(RoFormerPreTrainedModel, GenerationMixin): 
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `RoFormerForCausalLM` as a standalone, add `is_decoder=True`.")
+
+        self.roformer = RoFormerModel(config)
+        self.cls = RoFormerOnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.Tensor]]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, RoFormerForCausalLM, RoFormerConfig
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
+        >>> config = RoFormerConfig.from_pretrained("junnyu/roformer_chinese_base")
+        >>> config.is_decoder = True
+        >>> model = RoFormerForCausalLM.from_pretrained("junnyu/roformer_chinese_base", config=config)
+
+        >>> inputs = tokenizer("今天天气非常好。", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.roformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[1:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
+            )
+        return reordered_past
+
+
+class RoFormerClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.config = config
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = ACT2FN[self.config.hidden_act](x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+@add_start_docstrings(
+    """
+    RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+    pooled output) e.g. for GLUE tasks.
+ """, + ROFORMER_START_DOCSTRING, +) +class RoFormerForSequenceClassification(RoFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.roformer = RoFormerModel(config) + self.classifier = RoFormerClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + ROFORMER_START_DOCSTRING, +) +class RoFormerForMultipleChoice(RoFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.roformer = RoFormerModel(config) + self.sequence_summary = RoFormerSequenceSummary(config) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.roformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + pooled_output = self.sequence_summary(sequence_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ROFORMER_START_DOCSTRING, +) +class RoFormerForTokenClassification(RoFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roformer = RoFormerModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + ROFORMER_START_DOCSTRING, +) +class RoFormerForQuestionAnswering(RoFormerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.roformer = RoFormerModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor]]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
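+
+        Example (an illustrative sketch; this checkpoint ships without a trained QA head, so the span
+        logits below come from a freshly initialized layer):
+
+        ```python
+        >>> from transformers import AutoTokenizer, RoFormerForQuestionAnswering
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
+        >>> model = RoFormerForQuestionAnswering.from_pretrained("junnyu/roformer_chinese_base")
+
+        >>> inputs = tokenizer("今天天气非常好。", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> start_logits, end_logits = outputs.start_logits, outputs.end_logits
+        ```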
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "RoFormerForCausalLM", + "RoFormerForMaskedLM", + "RoFormerForMultipleChoice", + "RoFormerForQuestionAnswering", + "RoFormerForSequenceClassification", + "RoFormerForTokenClassification", + "RoFormerLayer", + "RoFormerModel", + "RoFormerPreTrainedModel", + "load_tf_weights_in_roformer", +] diff --git a/docs/transformers/src/transformers/models/roformer/modeling_tf_roformer.py b/docs/transformers/src/transformers/models/roformer/modeling_tf_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..738f8e67e9b9886fc56c71f128ec38c5affc5d1b --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/modeling_tf_roformer.py @@ -0,0 +1,1547 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 RoFormer model.""" + +from __future__ import annotations + +import math +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFCausalLMOutput, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFSequenceSummary, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_roformer import RoFormerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base" +_CONFIG_FOR_DOC = "RoFormerConfig" + + +class TFRoFormerSinusoidalPositionalEmbedding(keras.layers.Layer): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, **kwargs): + super().__init__(**kwargs) + + if embedding_dim % 2 != 0: + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") + + self.embedding_dim = embedding_dim + self.num_positions = num_positions + + def build(self, input_shape: tf.TensorShape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + + weight = self._init_weight(self.num_positions, self.embedding_dim) + + self.weight = self.add_weight( + name="embeddings", + shape=[self.num_positions, self.embedding_dim], + ) + weight = tf.cast(weight, dtype=self.weight.dtype) + + self.weight.assign(weight) + + super().build(input_shape) + + @staticmethod + def _init_weight(n_pos: int, dim: int): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. 
[dim // 2:]
+        """
+        position_enc = np.array(
+            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+        )
+        table = np.zeros_like(position_enc)
+        # sin features fill the first half of the vector, cos features the second half
+        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
+        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
+        # convert to a constant tensor; `tf.stop_gradient` returns a new tensor, so keep the result
+        table = tf.convert_to_tensor(table)
+        table = tf.stop_gradient(table)
+        return table
+
+    def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
+        """Input is expected to be of size [bsz x seqlen]."""
+        bsz, seq_len = input_shape[:2]
+
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+        return tf.gather(self.weight, positions)
+
+
+class TFRoFormerEmbeddings(keras.layers.Layer):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config: RoFormerConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
+        self.embedding_size = config.embedding_size
+        self.initializer_range = config.initializer_range
+        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+    def build(self, input_shape=None):
+        with tf.name_scope("word_embeddings"):
+            self.weight = self.add_weight(
+                name="weight",
+                shape=[self.config.vocab_size, self.embedding_size],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        with tf.name_scope("token_type_embeddings"):
+            self.token_type_embeddings = self.add_weight(
+                name="embeddings",
+                shape=[self.config.type_vocab_size, self.embedding_size],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "LayerNorm", None) is not None:
+            with tf.name_scope(self.LayerNorm.name):
+                self.LayerNorm.build([None, None, self.config.embedding_size])
+
+    def call(
+        self,
+        input_ids: Optional[tf.Tensor] = None,
+        token_type_ids: Optional[tf.Tensor] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        training: bool = False,
+    ) -> tf.Tensor:
+        """
+        Applies embedding based on inputs tensor.
+
+        Returns:
+            final_embeddings (`tf.Tensor`): output embedding tensor.
+ """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +class TFRoFormerSelfAttention(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.rotary_value = config.rotary_value + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + sinusoidal_pos: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + mixed_key_layer = self.key(inputs=hidden_states) + mixed_value_layer = self.value(inputs=hidden_states) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + if sinusoidal_pos is not None: + if self.rotary_value: + query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_layer, key_layer, value_layer + ) + else: + query_layer, key_layer = self.apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFRoFormerModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + return outputs + + @staticmethod + def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): + # https://kexue.fm/archives/8265 + # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] + # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] + sin, cos = tf.split(sinusoidal_pos, num_or_size_splits=2, axis=-1) + # sin [θ0,θ1,θ2......θd/2-1]-> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + # cos [θ0,θ1,θ2......θd/2-1]-> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + sin_pos = tf.repeat(sin, 2, axis=-1) + cos_pos = tf.repeat(cos, 2, axis=-1) + # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] + rotate_half_query_layer = tf.stack([-query_layer[..., 1::2], query_layer[..., ::2]], axis=-1) + rotate_half_query_layer = tf.reshape(rotate_half_query_layer, shape_list(query_layer)) + query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos + # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + rotate_half_key_layer = tf.stack([-key_layer[..., 1::2], key_layer[..., ::2]], axis=-1) + rotate_half_key_layer = tf.reshape(rotate_half_key_layer, shape_list(key_layer)) + key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos + if value_layer is not None: + # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] + rotate_half_value_layer = tf.stack([-value_layer[..., 1::2], value_layer[..., ::2]], axis=-1) + rotate_half_value_layer = tf.reshape(rotate_half_value_layer, shape_list(value_layer)) + value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos + return query_layer, key_layer, value_layer + return query_layer, key_layer + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with 
Bert->RoFormer +class TFRoFormerSelfOutput(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFRoFormerAttention(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFRoFormerSelfAttention(config, name="self") + self.dense_output = TFRoFormerSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + sinusoidal_pos: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RoFormer +class TFRoFormerIntermediate(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied 
from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RoFormer +class TFRoFormerOutput(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFRoFormerLayer(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFRoFormerAttention(config, name="attention") + self.intermediate = TFRoFormerIntermediate(config, name="intermediate") + self.roformer_output = TFRoFormerOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + sinusoidal_pos: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, + output_attentions=output_attentions, + training=training, + ) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.roformer_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "roformer_output", None) is not None: + with tf.name_scope(self.roformer_output.name): + self.roformer_output.build(None) + + +class TFRoFormerEncoder(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + self.embed_positions = TFRoFormerSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.hidden_size // config.num_attention_heads, + name="embed_positions", + ) + self.layer = [TFRoFormerLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + 
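+        # Illustrative note (assumed values, not executed): with embed_size_per_head = 4, row p of the
+        # sinusoidal table is [sin(p*t0), sin(p*t1), cos(p*t0), cos(p*t1)]. Inside
+        # `apply_rotary_position_embeddings` this rotates every (even, odd) feature pair of the query and
+        # key by a position-dependent angle:
+        #     q'[2i]   = q[2i] * cos(p*t_i) - q[2i+1] * sin(p*t_i)
+        #     q'[2i+1] = q[2i+1] * cos(p*t_i) + q[2i] * sin(p*t_i)
+        # so relative positions enter the attention scores purely through the q/k dot products.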
all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] + sinusoidal_pos = self.embed_positions(shape_list(hidden_states)[:-1])[None, None, :, :] + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask[i], + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFRoFormerPredictionHeadTransform(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.embedding_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = get_tf_activation(config.hidden_act) + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(inputs=hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.embedding_size]) + + +class TFRoFormerLMPredictionHead(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = config.embedding_size + + self.transform = TFRoFormerPredictionHeadTransform(config, name="transform") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
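+        # Sketch of what `call` computes below (illustrative): logits = hidden @ input_embeddings.weight^T + bias,
+        # i.e. hidden states of shape (batch, seq_len, embedding_size) are projected onto the
+        # (vocab_size, embedding_size) embedding matrix, so resizing the input embeddings also resizes this head.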
+ self.input_embeddings = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "transform", None) is not None: + with tf.name_scope(self.transform.name): + self.transform.build(None) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self.input_embeddings + + def set_output_embeddings(self, value: tf.Variable): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self) -> Dict[str, tf.Variable]: + return {"bias": self.bias} + + def set_bias(self, value: tf.Variable): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.transform(hidden_states=hidden_states) + seq_length = shape_list(hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->RoFormer +class TFRoFormerMLMHead(keras.layers.Layer): + def __init__(self, config: RoFormerConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.predictions = TFRoFormerLMPredictionHead(config, input_embeddings, name="predictions") + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + prediction_scores = self.predictions(hidden_states=sequence_output) + + return prediction_scores + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + + +@keras_serializable +class TFRoFormerMainLayer(keras.layers.Layer): + config_class = RoFormerConfig + + def __init__(self, config: RoFormerConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFRoFormerEmbeddings(config, name="embeddings") + if config.embedding_size != config.hidden_size: + self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project") + + self.encoder = TFRoFormerEncoder(config, name="encoder") + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(dims=input_shape, value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + training=training, + ) + if hasattr(self, "embeddings_project"): + embedding_output = self.embeddings_project(embedding_output, training=training) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
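+        # Worked example (illustrative): attention_mask = [[1, 1, 0]] is reshaped to (1, 1, 1, 3) and becomes
+        #     extended_attention_mask = [[[[0.0, 0.0, -10000.0]]]]
+        # i.e. (1.0 - mask) * -10000.0, which is then added to the raw attention scores.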
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+
+        if not return_dict:
+            return (sequence_output,) + encoder_outputs[1:]
+
+        return TFBaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "embeddings_project", None) is not None:
+            with tf.name_scope(self.embeddings_project.name):
+                self.embeddings_project.build([None, None, self.config.embedding_size])
+
+
+class TFRoFormerPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = RoFormerConfig
+    base_model_prefix = "roformer"
+
+
+ROFORMER_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports!
If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated with the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/), you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+ Args:
+ config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ROFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to `True`.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
+ ROFORMER_START_DOCSTRING,
+)
+class TFRoFormerModel(TFRoFormerPreTrainedModel):
+ def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.roformer = TFRoFormerMainLayer(config, name="roformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+ outputs = self.roformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "roformer", None) is not None:
+ with tf.name_scope(self.roformer.name):
+ self.roformer.build(None)
+
+
+@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
+class TFRoFormerForMaskedLM(TFRoFormerPreTrainedModel, TFMaskedLanguageModelingLoss):
+ def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `TFRoFormerForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.roformer = TFRoFormerMainLayer(config, name="roformer")
+ self.mlm = TFRoFormerMLMHead(config, input_embeddings=self.roformer.embeddings, name="mlm___cls")
+
+ def get_lm_head(self) -> keras.layers.Layer:
+ return self.mlm.predictions
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFMaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ outputs = self.roformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFMaskedLMOutput(
+ loss=loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "roformer", None) is not None:
+ with tf.name_scope(self.roformer.name):
+ self.roformer.build(None)
+ if getattr(self, "mlm", None) is not None:
+ with tf.name_scope(self.mlm.name):
+ self.mlm.build(None)
+
+
+@add_start_docstrings(
+ """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
+)
+class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingLoss):
+ def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `TFRoFormerForCausalLM` as a standalone, add `is_decoder=True`.")
+
+ self.roformer = TFRoFormerMainLayer(config, name="roformer")
+ self.mlm = TFRoFormerMLMHead(config, input_embeddings=self.roformer.embeddings, name="mlm___cls")
+
+ def get_lm_head(self) -> keras.layers.Layer:
+ return self.mlm.predictions
+
+ @unpack_inputs
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFCausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the cross-entropy classification loss. Indices should be in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+ outputs = self.roformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ sequence_output = outputs[0]
+ logits = self.mlm(sequence_output=sequence_output, training=training)
+ loss = None
+
+ if labels is not None:
+ # shift labels to the left and cut the last logit token, so that
+ # position t predicts token t + 1
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "roformer", None) is not None:
+ with tf.name_scope(self.roformer.name):
+ self.roformer.build(None)
+ if getattr(self, "mlm", None) is not None:
+ with tf.name_scope(self.mlm.name):
+ self.mlm.build(None)
+
+
+class TFRoFormerClassificationHead(keras.layers.Layer):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config: RoFormerConfig, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ self.dense = keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+ self.out_proj = keras.layers.Dense(
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.classifier_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.classifier_act_fn = config.hidden_act
+ self.config = config
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = hidden_states[:, 0, :] # take <s> token (equiv.
to [CLS]) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.classifier_act_fn(hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.out_proj(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoFormer Model transformer with a sequence classification/regression head on top e.g., for GLUE tasks. + """, + ROFORMER_START_DOCSTRING, +) +class TFRoFormerForSequenceClassification(TFRoFormerPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: RoFormerConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.roformer = TFRoFormerMainLayer(config, name="roformer") + self.classifier = TFRoFormerClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + outputs = self.roformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + logits = self.classifier(hidden_states=outputs[0], training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roformer", None) is not None: + with tf.name_scope(self.roformer.name): + self.roformer.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ROFORMER_START_DOCSTRING, +) +class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLoss): + def __init__(self, config: RoFormerConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roformer = TFRoFormerMainLayer(config, name="roformer") + self.sequence_summary = TFSequenceSummary(config, config.initializer_range, name="sequence_summary") + self.classifier = keras.layers.Dense( + units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward( + ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None + flat_attention_mask = ( + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None + ) + flat_token_type_ids = ( + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None + ) + flat_inputs_embeds = ( + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.roformer( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + logits = self.sequence_summary(inputs=outputs[0], training=training) + logits = self.classifier(inputs=logits) + reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roformer", None) is not None: + with tf.name_scope(self.roformer.name): + self.roformer.build(None) + if getattr(self, "sequence_summary", None) is not None: + with tf.name_scope(self.sequence_summary.name): + self.sequence_summary.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ROFORMER_START_DOCSTRING, +) +class TFRoFormerForTokenClassification(TFRoFormerPreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config: RoFormerConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.roformer = TFRoFormerMainLayer(config, name="roformer") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.roformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(inputs=sequence_output, training=training) + logits = self.classifier(inputs=sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roformer", None) is not None: + with tf.name_scope(self.roformer.name): + self.roformer.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + ROFORMER_START_DOCSTRING, +) +class TFRoFormerForQuestionAnswering(TFRoFormerPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config: RoFormerConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.roformer = TFRoFormerMainLayer(config, name="roformer") + self.qa_outputs = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]: + r""" + start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + outputs = self.roformer( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.qa_outputs(inputs=sequence_output) + start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) + start_logits = tf.squeeze(input=start_logits, axis=-1) + end_logits = tf.squeeze(input=end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions, "end_position": end_positions} + loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "roformer", None) is not None: + with tf.name_scope(self.roformer.name): + self.roformer.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFRoFormerForCausalLM", + "TFRoFormerForMaskedLM", + "TFRoFormerForMultipleChoice", + "TFRoFormerForQuestionAnswering", + "TFRoFormerForSequenceClassification", + "TFRoFormerForTokenClassification", + "TFRoFormerLayer", + "TFRoFormerModel", + "TFRoFormerPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/roformer/tokenization_roformer.py b/docs/transformers/src/transformers/models/roformer/tokenization_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..312cecca193b7988138cc7d2e3ad7472488a5ae8 --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/tokenization_roformer.py @@ -0,0 +1,540 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes for RoFormer.""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +# Copied from transformers.models.bert.tokenization_bert.load_vocab +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. 
This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ # prevents treating the same character with different Unicode codepoints as different characters
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and are handled
+ # like all of the other languages.
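+ # For example, ord("中") == 0x4E2D falls in the 0x4E00-0x9FFF range below,
+ # so `_tokenize_chinese_chars` surrounds it with spaces, while Hangul and
+ # Hiragana codepoints fall outside these ranges and are left untouched.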
+ if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +class RoFormerTokenizer(PreTrainedTokenizer): + r""" + Construct a RoFormer tokenizer. Based on [Rust Jieba](https://pypi.org/project/rjieba/). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. 
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + + Example: + + ```python + >>> from transformers import RoFormerTokenizer + + >>> tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base") + >>> tokenizer.tokenize("今天天气非常好。") + ['今', '天', '天', '气', '非常', '好', '。'] + ```""" + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + try: + import rjieba + except ImportError: + raise ImportError( + "You need to install rjieba to use RoFormerTokenizer. " + "See https://pypi.org/project/rjieba/ for installation." 
+ ) + self.jieba = rjieba + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def __getstate__(self): + state = self.__dict__.copy() + state["jieba"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + import rjieba + + self.jieba = rjieba + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text, use_jieba=True): + split_tokens = [] + if use_jieba: + for wholword in self.jieba.cut(text, False): + if wholword in self.vocab: + split_tokens.append(wholword) + else: + # use bert tokenizer to _tokenize + char_list = self._tokenize(wholword, use_jieba=False) + split_tokens.extend(char_list) + else: + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A RoFormer sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. 
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +__all__ = ["RoFormerTokenizer"] diff --git a/docs/transformers/src/transformers/models/roformer/tokenization_roformer_fast.py b/docs/transformers/src/transformers/models/roformer/tokenization_roformer_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..7f75e5a2ff12b8d0bce8f80bec2b2cbbfd31e4fb --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/tokenization_roformer_fast.py @@ -0,0 +1,180 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
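+#
+# Note (editorial sketch of the design below): `PreTokenizer.custom(...)` wraps
+# an arbitrary Python object, which the `tokenizers` library cannot serialize
+# into `tokenizer.json`. This is presumably why `__getstate__` and
+# `save_pretrained` in this class temporarily swap in a plain
+# `BertPreTokenizer` before pickling/saving, and `__setstate__` re-attaches the
+# custom `JiebaPreTokenizer` on load.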
+"""Tokenization classes for RoFormer.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers +from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_roformer import RoFormerTokenizer +from .tokenization_utils import JiebaPreTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class RoFormerTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" RoFormer tokenizer (backed by HuggingFace's *tokenizers* library). + + [`RoFormerTokenizerFast`] is almost identical to [`BertTokenizerFast`] and runs end-to-end tokenization: + punctuation splitting and wordpiece. There are some difference between them when tokenizing Chinese. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Example: + + ```python + >>> from transformers import RoFormerTokenizerFast + + >>> tokenizer = RoFormerTokenizerFast.from_pretrained("junnyu/roformer_chinese_base") + >>> tokenizer.tokenize("今天天气非常好。") + ['今', '天', '天', '气', '非常', '好', '。'] + ```""" + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = RoFormerTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + # Make sure we correctly set the custom PreTokenizer + vocab = self.backend_tokenizer.get_vocab() + self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab)) + + self.do_lower_case = do_lower_case + + def __getstate__(self): + state = self.__dict__.copy() + state["_tokenizer"].pre_tokenizer = BertPreTokenizer() + return state + + def __setstate__(self, d): + self.__dict__ = d + vocab = self.__dict__["_tokenizer"].get_vocab() + self.__dict__["_tokenizer"].pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab)) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A RoFormer sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def save_pretrained( + self, + save_directory, + legacy_format=None, + filename_prefix=None, + push_to_hub=False, + **kwargs, + ): + self.backend_tokenizer.pre_tokenizer = BertPreTokenizer() + return super().save_pretrained(save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs) + + +__all__ = ["RoFormerTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/roformer/tokenization_utils.py b/docs/transformers/src/transformers/models/roformer/tokenization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9cf6cb0a0b7f08ba43e8761b02862a5c368951 --- /dev/null +++ b/docs/transformers/src/transformers/models/roformer/tokenization_utils.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization utils for RoFormer.""" + +from typing import List + +from tokenizers import NormalizedString, PreTokenizedString, normalizers + + +class JiebaPreTokenizer: + def __init__(self, vocab) -> None: + self.vocab = vocab + self.normalizers = normalizers.BertNormalizer( + clean_text=False, + handle_chinese_chars=True, + strip_accents=False, + lowercase=False, + ) + try: + import rjieba + except ImportError: + raise ImportError( + "You need to install rjieba to use RoFormerTokenizer. " + "See https://pypi.org/project/rjieba/ for installation." 
+ ) + self.jieba = rjieba + + def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]: + splits = [] + + # this code slice normalized_string is too slow (6s) but test_alignment_methods can pass + for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False): + if token in self.vocab: + splits.append(normalized_string[start:end]) + else: + token_list = self.normalizers.normalize_str(token).split() + for token in token_list: + if token: + end = start + len(token) + splits.append(normalized_string[start:end]) + start = end + + # this code test_alignment_methods can't pass but fast (300ms) + # for token in self.jieba.cut(str(normalized_string), False): + # if token in self.vocab: + # splits.append(NormalizedString(token)) + # else: + # token_list = self.normalizers.normalize_str(token).split() + # for token in token_list: + # if token: + # splits.append(NormalizedString(token)) + + return splits + + def pre_tokenize(self, pretok: PreTokenizedString): + pretok.split(self.jieba_split) diff --git a/docs/transformers/src/transformers/models/rt_detr/__init__.py b/docs/transformers/src/transformers/models/rt_detr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd1fac5b679dc5ca527dcc0a6a45b9a873daa3dc --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_rt_detr import * + from .configuration_rt_detr_resnet import * + from .image_processing_rt_detr import * + from .image_processing_rt_detr_fast import * + from .modeling_rt_detr import * + from .modeling_rt_detr_resnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/rt_detr/configuration_rt_detr.py b/docs/transformers/src/transformers/models/rt_detr/configuration_rt_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc8ea550ad9e40d70da5049071d2f5dbcac2ab0 --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr/configuration_rt_detr.py @@ -0,0 +1,364 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RT-DETR model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+from .configuration_rt_detr_resnet import RTDetrResNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class RTDetrConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate an
+    RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the RT-DETR
+    [checkpoint/todo](https://huggingface.co/checkpoint/todo) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        initializer_range (`float`, *optional*, defaults to 0.01):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_bias_prior_prob (`float`, *optional*):
+            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+            If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the batch normalization layers.
+        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint,
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        encoder_hidden_dim (`int`, *optional*, defaults to 256):
+            Dimension of the layers in the hybrid encoder.
+        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+            Multi-level feature inputs for the encoder.
+        feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
+            Strides used in each feature map.
+        encoder_layers (`int`, *optional*, defaults to 1):
+            Total number of layers to be used by the encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The ratio for all dropout layers.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
+            Indexes of the projected layers to be used in the encoder.
+        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+            The temperature parameter used to create the positional encodings.
+        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_function (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        eval_size (`Tuple[int, int]`, *optional*):
+            Height and width used to compute the effective height and width of the position embeddings after taking
+            into account the stride.
+        normalize_before (`bool`, *optional*, defaults to `False`):
+            Determines whether to apply layer normalization in the transformer encoder layer before self-attention and
+            feed-forward modules.
+        hidden_expansion (`float`, *optional*, defaults to 1.0):
+            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers, excluding the hybrid encoder.
+        num_queries (`int`, *optional*, defaults to 300):
+            Number of object queries.
+        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+            Multi-level feature dimensions for the decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        num_feature_levels (`int`, *optional*, defaults to 3):
+            The number of input feature levels.
+        decoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the decoder.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_denoising (`int`, *optional*, defaults to 100):
+            The total number of denoising tasks or queries to be used for contrastive denoising.
+        label_noise_ratio (`float`, *optional*, defaults to 0.5):
+            The fraction of denoising labels to which random noise should be added.
+        box_noise_scale (`float`, *optional*, defaults to 1.0):
+            Scale or magnitude of noise to be added to the bounding boxes.
+        learn_initial_query (`bool`, *optional*, defaults to `False`):
+            Indicates whether the initial query embeddings for the decoder should be learned during training.
+        anchor_image_size (`Tuple[int, int]`, *optional*):
+            Height and width of the input image used during evaluation to generate the bounding box anchors.
+            If `None`, anchors are generated automatically.
+        disable_custom_kernels (`bool`, *optional*, defaults to `True`):
+            Whether to disable custom kernels.
+        with_box_refine (`bool`, *optional*, defaults to `True`):
+            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+            based on the predictions from the previous layer.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the architecture has an encoder-decoder structure.
+        matcher_alpha (`float`, *optional*, defaults to 0.25):
+            Parameter alpha used by the Hungarian Matcher.
+        matcher_gamma (`float`, *optional*, defaults to 2.0):
+            Parameter gamma used by the Hungarian Matcher.
+        matcher_class_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the class loss used by the Hungarian Matcher.
+        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+            The relative weight of the bounding box loss used by the Hungarian Matcher.
+        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the giou loss used by the Hungarian Matcher.
+        use_focal_loss (`bool`, *optional*, defaults to `True`):
+            Whether the focal loss should be used.
+        auxiliary_loss (`bool`, *optional*, defaults to `True`):
+            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+        focal_loss_alpha (`float`, *optional*, defaults to 0.75):
+            Parameter alpha used to compute the focal loss.
+        focal_loss_gamma (`float`, *optional*, defaults to 2.0):
+            Parameter gamma used to compute the focal loss.
+        weight_loss_vfl (`float`, *optional*, defaults to 1.0):
+            Relative weight of the varifocal loss in the object detection loss.
+        weight_loss_bbox (`float`, *optional*, defaults to 5.0):
+            Relative weight of the L1 bounding box loss in the object detection loss.
+        weight_loss_giou (`float`, *optional*, defaults to 2.0):
+            Relative weight of the generalized IoU loss in the object detection loss.
+        eos_coefficient (`float`, *optional*, defaults to 0.0001):
+            Relative classification weight of the 'no-object' class in the object detection loss.
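+            Taken together, the detection loss is roughly the weighted sum `weight_loss_vfl * L_vfl +
+            weight_loss_bbox * L_bbox + weight_loss_giou * L_giou`, with the 'no-object' class
+            down-weighted by `eos_coefficient` in the classification term.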
+ + Examples: + + ```python + >>> from transformers import RTDetrConfig, RTDetrModel + + >>> # Initializing a RT-DETR configuration + >>> configuration = RTDetrConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RTDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rt_detr" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + freeze_backbone_batch_norms=True, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RTDetrTransformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + disable_custom_kernels=True, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + eos_coefficient=1e-4, + **kwargs, + ): + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone." 
+ ) + backbone_config = RTDetrResNetConfig( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + # decoder + self.d_model = d_model + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. 
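+
+        For example (an illustrative sketch):
+
+        ```python
+        >>> from transformers import RTDetrConfig, RTDetrResNetConfig
+
+        >>> backbone_config = RTDetrResNetConfig(depths=[2, 2, 2, 2], layer_type="basic")
+        >>> config = RTDetrConfig.from_backbone_configs(backbone_config)
+        ```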
+
+        Args:
+            backbone_config ([`PretrainedConfig`]):
+                The backbone configuration.
+
+        Returns:
+            [`RTDetrConfig`]: An instance of a configuration object.
+        """
+        return cls(
+            backbone_config=backbone_config,
+            **kwargs,
+        )
+
+
+__all__ = ["RTDetrConfig"]
diff --git a/docs/transformers/src/transformers/models/rt_detr/configuration_rt_detr_resnet.py b/docs/transformers/src/transformers/models/rt_detr/configuration_rt_detr_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e74cd2f1819cf470f50b2c151e4e8406b7839c
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rt_detr/configuration_rt_detr_resnet.py
@@ -0,0 +1,114 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RT-DETR ResNet model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate a
+    ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the ResNet
+    [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        embedding_size (`int`, *optional*, defaults to 64):
+            Dimensionality (hidden size) for the embedding layer.
+        hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
+            Dimensionality (hidden size) at each stage.
+        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
+            Depth (number of layers) for each stage.
+        layer_type (`str`, *optional*, defaults to `"bottleneck"`):
+            The layer to use; it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
+            `"bottleneck"` (used for larger models like resnet-50 and above).
+        hidden_act (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
+            are supported.
+        downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
+            If `True`, the first stage will downsample the inputs using a `stride` of 2.
+        downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
+            If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + ```python + >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone + + >>> # Initializing a ResNet resnet-50 style configuration + >>> configuration = RTDetrResNetConfig() + + >>> # Initializing a model (with random weights) from the resnet-50 style configuration + >>> model = RTDetrResnetBackbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_resnet" + layer_types = ["basic", "bottleneck"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.downsample_in_first_stage = downsample_in_first_stage + self.downsample_in_bottleneck = downsample_in_bottleneck + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["RTDetrResNetConfig"] diff --git a/docs/transformers/src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py b/docs/transformers/src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..9f2271930e13781175e2f7e8a71fc2cd53a9eda7 --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py @@ -0,0 +1,782 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
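+
+# Example invocation (illustrative):
+#
+#     python convert_rt_detr_original_pytorch_checkpoint_to_hf.py \
+#         --model_name rtdetr_r50vd --pytorch_dump_folder_path ./rtdetr_r50vd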
+"""Convert RT Detr checkpoints with Timm backbone""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms + +from transformers import RTDetrConfig, RTDetrForObjectDetection, RTDetrImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_rt_detr_config(model_name: str) -> RTDetrConfig: + config = RTDetrConfig() + + config.num_labels = 80 + repo_id = "huggingface/label-files" + filename = "coco-detection-mmdet-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + if model_name == "rtdetr_r18vd": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [2, 2, 2, 2] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 3 + elif model_name == "rtdetr_r34vd": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [3, 4, 6, 3] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 4 + elif model_name == "rtdetr_r50vd_m": + pass + elif model_name == "rtdetr_r50vd": + pass + elif model_name == "rtdetr_r101vd": + config.backbone_config.depths = [3, 4, 23, 3] + config.encoder_ffn_dim = 2048 + config.encoder_hidden_dim = 384 + config.decoder_in_channels = [384, 384, 384] + elif model_name == "rtdetr_r18vd_coco_o365": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [2, 2, 2, 2] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 3 + elif model_name == "rtdetr_r50vd_coco_o365": + pass + elif model_name == "rtdetr_r101vd_coco_o365": + config.backbone_config.depths = [3, 4, 23, 3] + config.encoder_ffn_dim = 2048 + config.encoder_hidden_dim = 384 + config.decoder_in_channels = [384, 384, 384] + + return config + + +def create_rename_keys(config): + # here we list all keys to be renamed (original name on the left, our name on the right) + rename_keys = [] + + # stem + # fmt: off + last_key = ["weight", "bias", "running_mean", "running_var"] + + for level in range(3): + rename_keys.append((f"backbone.conv1.conv1_{level+1}.conv.weight", f"model.backbone.model.embedder.embedder.{level}.convolution.weight")) + for last in last_key: + rename_keys.append((f"backbone.conv1.conv1_{level+1}.norm.{last}", f"model.backbone.model.embedder.embedder.{level}.normalization.{last}")) + + for stage_idx in range(len(config.backbone_config.depths)): + for layer_idx in range(config.backbone_config.depths[stage_idx]): + # shortcut + if layer_idx == 0: + if stage_idx == 0: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.normalization.{last}", + ) + ) + else: + rename_keys.append( + ( + 
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.normalization.{last}", + ) + ) + + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append(( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.{last}", + )) + + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append(( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.{last}", + )) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/nn/backbone/presnet.py#L171 + if config.backbone_config.layer_type != "basic": + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append(( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.normalization.{last}", + )) + # fmt: on + + for i in range(config.encoder_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.self_attn.out_proj.weight", + f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.self_attn.out_proj.bias", + f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear1.weight", + f"model.encoder.encoder.{i}.layers.0.fc1.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear1.bias", + f"model.encoder.encoder.{i}.layers.0.fc1.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear2.weight", + f"model.encoder.encoder.{i}.layers.0.fc2.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear2.bias", + f"model.encoder.encoder.{i}.layers.0.fc2.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm1.weight", + f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm1.bias", + f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm2.weight", + f"model.encoder.encoder.{i}.layers.0.final_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm2.bias", + f"model.encoder.encoder.{i}.layers.0.final_layer_norm.bias", + ) + ) + + for j in range(0, 3): + 
rename_keys.append((f"encoder.input_proj.{j}.0.weight", f"model.encoder_input_proj.{j}.0.weight")) + for last in last_key: + rename_keys.append((f"encoder.input_proj.{j}.1.{last}", f"model.encoder_input_proj.{j}.1.{last}")) + + block_levels = 3 if config.backbone_config.layer_type != "basic" else 4 + + for i in range(len(config.encoder_in_channels) - 1): + # encoder layers: hybridencoder parts + for j in range(1, block_levels): + rename_keys.append( + (f"encoder.fpn_blocks.{i}.conv{j}.conv.weight", f"model.encoder.fpn_blocks.{i}.conv{j}.conv.weight") + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.fpn_blocks.{i}.conv{j}.norm.{last}", + f"model.encoder.fpn_blocks.{i}.conv{j}.norm.{last}", + ) + ) + + rename_keys.append((f"encoder.lateral_convs.{i}.conv.weight", f"model.encoder.lateral_convs.{i}.conv.weight")) + for last in last_key: + rename_keys.append( + (f"encoder.lateral_convs.{i}.norm.{last}", f"model.encoder.lateral_convs.{i}.norm.{last}") + ) + + for j in range(3): + for k in range(1, 3): + rename_keys.append( + ( + f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + ) + ) + + for j in range(1, block_levels): + rename_keys.append( + (f"encoder.pan_blocks.{i}.conv{j}.conv.weight", f"model.encoder.pan_blocks.{i}.conv{j}.conv.weight") + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.pan_blocks.{i}.conv{j}.norm.{last}", + f"model.encoder.pan_blocks.{i}.conv{j}.norm.{last}", + ) + ) + + for j in range(3): + for k in range(1, 3): + rename_keys.append( + ( + f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + ) + ) + + rename_keys.append( + (f"encoder.downsample_convs.{i}.conv.weight", f"model.encoder.downsample_convs.{i}.conv.weight") + ) + for last in last_key: + rename_keys.append( + (f"encoder.downsample_convs.{i}.norm.{last}", f"model.encoder.downsample_convs.{i}.norm.{last}") + ) + + for i in range(config.decoder_layers): + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.self_attn.out_proj.weight", + f"model.decoder.layers.{i}.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.self_attn.out_proj.bias", + f"model.decoder.layers.{i}.self_attn.out_proj.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.weight", + f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.bias", + f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.attention_weights.weight", + f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.attention_weights.bias", + f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias", + ) + ) + rename_keys.append( + ( + 
f"decoder.decoder.layers.{i}.cross_attn.value_proj.weight", + f"model.decoder.layers.{i}.encoder_attn.value_proj.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.value_proj.bias", + f"model.decoder.layers.{i}.encoder_attn.value_proj.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.output_proj.weight", + f"model.decoder.layers.{i}.encoder_attn.output_proj.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.output_proj.bias", + f"model.decoder.layers.{i}.encoder_attn.output_proj.bias", + ) + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias") + ) + rename_keys.append((f"decoder.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight")) + rename_keys.append((f"decoder.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias")) + rename_keys.append((f"decoder.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight")) + rename_keys.append((f"decoder.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias")) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias") + ) + + for i in range(config.decoder_layers): + # decoder + class and bounding box heads + rename_keys.append( + ( + f"decoder.dec_score_head.{i}.weight", + f"model.decoder.class_embed.{i}.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_score_head.{i}.bias", + f"model.decoder.class_embed.{i}.bias", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.0.weight", + f"model.decoder.bbox_embed.{i}.layers.0.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.0.bias", + f"model.decoder.bbox_embed.{i}.layers.0.bias", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.1.weight", + f"model.decoder.bbox_embed.{i}.layers.1.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.1.bias", + f"model.decoder.bbox_embed.{i}.layers.1.bias", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.2.weight", + f"model.decoder.bbox_embed.{i}.layers.2.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.2.bias", + f"model.decoder.bbox_embed.{i}.layers.2.bias", + ) + ) + + # decoder projection + for i in range(len(config.decoder_in_channels)): + rename_keys.append( + ( + f"decoder.input_proj.{i}.conv.weight", + f"model.decoder_input_proj.{i}.0.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"decoder.input_proj.{i}.norm.{last}", + f"model.decoder_input_proj.{i}.1.{last}", + ) + ) + + # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads + rename_keys.extend( + [ + ("decoder.denoising_class_embed.weight", "model.denoising_class_embed.weight"), + ("decoder.query_pos_head.layers.0.weight", "model.decoder.query_pos_head.layers.0.weight"), + 
("decoder.query_pos_head.layers.0.bias", "model.decoder.query_pos_head.layers.0.bias"), + ("decoder.query_pos_head.layers.1.weight", "model.decoder.query_pos_head.layers.1.weight"), + ("decoder.query_pos_head.layers.1.bias", "model.decoder.query_pos_head.layers.1.bias"), + ("decoder.enc_output.0.weight", "model.enc_output.0.weight"), + ("decoder.enc_output.0.bias", "model.enc_output.0.bias"), + ("decoder.enc_output.1.weight", "model.enc_output.1.weight"), + ("decoder.enc_output.1.bias", "model.enc_output.1.bias"), + ("decoder.enc_score_head.weight", "model.enc_score_head.weight"), + ("decoder.enc_score_head.bias", "model.enc_score_head.bias"), + ("decoder.enc_bbox_head.layers.0.weight", "model.enc_bbox_head.layers.0.weight"), + ("decoder.enc_bbox_head.layers.0.bias", "model.enc_bbox_head.layers.0.bias"), + ("decoder.enc_bbox_head.layers.1.weight", "model.enc_bbox_head.layers.1.weight"), + ("decoder.enc_bbox_head.layers.1.bias", "model.enc_bbox_head.layers.1.bias"), + ("decoder.enc_bbox_head.layers.2.weight", "model.enc_bbox_head.layers.2.weight"), + ("decoder.enc_bbox_head.layers.2.bias", "model.enc_bbox_head.layers.2.bias"), + ] + ) + + return rename_keys + + +def rename_key(state_dict, old, new): + try: + val = state_dict.pop(old) + state_dict[new] = val + except Exception: + pass + + +def read_in_q_k_v(state_dict, config): + prefix = "" + encoder_hidden_dim = config.encoder_hidden_dim + + # first: transformer encoder + for i in range(config.encoder_layers): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[ + :encoder_hidden_dim, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[ + encoder_hidden_dim : 2 * encoder_hidden_dim, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[ + -encoder_hidden_dim:, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(config.decoder_layers): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] 
+ state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + + return im + + +@torch.no_grad() +def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id): + """ + Copy/paste/tweak model's weights to our RTDETR structure. + """ + + # load default config + config = get_rt_detr_config(model_name) + + # load original model from torch hub + model_name_to_checkpoint_url = { + "rtdetr_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth", + "rtdetr_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth", + "rtdetr_r50vd_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth", + "rtdetr_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth", + "rtdetr_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth", + "rtdetr_r18vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth", + "rtdetr_r50vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth", + "rtdetr_r101vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_2x_coco_objects365_from_paddle.pth", + } + logger.info(f"Converting model {model_name}...") + state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[ + "ema" + ]["module"] + + # rename keys + for src, dest in create_rename_keys(config): + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict, config) + # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them + for key in state_dict.copy().keys(): + if key.endswith("num_batches_tracked"): + del state_dict[key] + # for two_stage + if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): + state_dict[key.split("model.decoder.")[-1]] = state_dict[key] + + # finally, create HuggingFace model and load state dict + model = RTDetrForObjectDetection(config) + model.load_state_dict(state_dict) + model.eval() + + # load image processor + image_processor = RTDetrImageProcessor() + + # prepare image + img = prepare_img() + + # preprocess image + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension + + encoding = image_processor(images=img, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + assert torch.allclose(original_pixel_values, pixel_values) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + pixel_values = pixel_values.to(device) + + # Pass image by the model + outputs = model(pixel_values) + + if model_name == "rtdetr_r18vd": + expected_slice_logits = torch.tensor( + [ + [-4.3364253, -6.465683, -3.6130402], + [-4.083815, -6.4039373, -6.97881], + [-4.192215, -7.3410473, -6.9027247], + ] + ) + expected_slice_boxes = 
torch.tensor( + [ + [0.16868353, 0.19833282, 0.21182671], + [0.25559652, 0.55121744, 0.47988364], + [0.7698693, 0.4124569, 0.46036878], + ] + ) + elif model_name == "rtdetr_r34vd": + expected_slice_logits = torch.tensor( + [ + [-4.3727384, -4.7921476, -5.7299604], + [-4.840536, -8.455345, -4.1745796], + [-4.1277084, -5.2154565, -5.7852697], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.258278, 0.5497808, 0.4732004], + [0.16889669, 0.19890057, 0.21138911], + [0.76632994, 0.4147879, 0.46851268], + ] + ) + elif model_name == "rtdetr_r50vd_m": + expected_slice_logits = torch.tensor( + [ + [-4.319764, -6.1349025, -6.094794], + [-5.1056995, -7.744766, -4.803956], + [-4.7685347, -7.9278393, -4.5751696], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.2582739, 0.55071366, 0.47660282], + [0.16811174, 0.19954777, 0.21292639], + [0.54986024, 0.2752091, 0.0561416], + ] + ) + elif model_name == "rtdetr_r50vd": + expected_slice_logits = torch.tensor( + [ + [-4.6476398, -5.001154, -4.9785104], + [-4.1593494, -4.7038546, -5.946485], + [-4.4374595, -4.658361, -6.2352347], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.16880608, 0.19992264, 0.21225442], + [0.76837635, 0.4122631, 0.46368608], + [0.2595386, 0.5483334, 0.4777486], + ] + ) + elif model_name == "rtdetr_r101vd": + expected_slice_logits = torch.tensor( + [ + [-4.6162, -4.9189, -4.6656], + [-4.4701, -4.4997, -4.9659], + [-5.6641, -7.9000, -5.0725], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.7707, 0.4124, 0.4585], + [0.2589, 0.5492, 0.4735], + [0.1688, 0.1993, 0.2108], + ] + ) + elif model_name == "rtdetr_r18vd_coco_o365": + expected_slice_logits = torch.tensor( + [ + [-4.8726, -5.9066, -5.2450], + [-4.8157, -6.8764, -5.1656], + [-4.7492, -5.7006, -5.1333], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.2552, 0.5501, 0.4773], + [0.1685, 0.1986, 0.2104], + [0.7692, 0.4141, 0.4620], + ] + ) + elif model_name == "rtdetr_r50vd_coco_o365": + expected_slice_logits = torch.tensor( + [ + [-4.6491, -3.9252, -5.3163], + [-4.1386, -5.0348, -3.9016], + [-4.4778, -4.5423, -5.7356], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.2583, 0.5492, 0.4747], + [0.5501, 0.2754, 0.0574], + [0.7693, 0.4137, 0.4613], + ] + ) + elif model_name == "rtdetr_r101vd_coco_o365": + expected_slice_logits = torch.tensor( + [ + [-4.5152, -5.6811, -5.7311], + [-4.5358, -7.2422, -5.0941], + [-4.6919, -5.5834, -6.0145], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.7703, 0.4140, 0.4583], + [0.1686, 0.1991, 0.2107], + [0.2570, 0.5496, 0.4750], + ] + ) + else: + raise ValueError(f"Unknown rt_detr_name: {model_name}") + + assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4) + assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3) + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + # Upload model, image processor and config to the hub + logger.info("Uploading PyTorch model and image processor to the hub...") + config.push_to_hub( + repo_id=repo_id, commit_message="Add config from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py" + ) + model.push_to_hub( + repo_id=repo_id, commit_message="Add 
model from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py" + ) + image_processor.push_to_hub( + repo_id=repo_id, + commit_message="Add image processor from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + default="rtdetr_r50vd", + type=str, + help="model_name of the checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.") + parser.add_argument( + "--repo_id", + type=str, + help="repo_id where the model will be pushed to.", + ) + args = parser.parse_args() + convert_rt_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id) diff --git a/docs/transformers/src/transformers/models/rt_detr/image_processing_rt_detr.py b/docs/transformers/src/transformers/models/rt_detr/image_processing_rt_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..e458de37949b0f08ac390d2118167f4d15e4d06e --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr/image_processing_rt_detr.py @@ -0,0 +1,1102 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for RT-DETR.""" + +import pathlib +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor, get_size_dict +from ...image_transforms import ( + PaddingMode, + center_to_corners_format, + corners_to_center_format, + pad, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_annotations, + validate_preprocess_arguments, +) +from ...utils import ( + filter_out_non_signature_kwargs, + is_flax_available, + is_jax_tensor, + is_tf_available, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + logging, + requires_backends, +) +from ...utils.generic import TensorType + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) + + +# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. 
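+
+    For example, an input of `image_size=(480, 640)` with `size=800` and no `max_size` returns
+    `(800, 1066)`: the shorter edge (here the height) is scaled to 800 and the longer edge is
+    scaled by the same factor to preserve the aspect ratio.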
+
+    Args:
+        image_size (`Tuple[int, int]`):
+            The input image size.
+        size (`int`):
+            The desired output size.
+        max_size (`int`, *optional*):
+            The maximum allowed output size.
+    """
+    height, width = image_size
+    raw_size = None
+    if max_size is not None:
+        min_original_size = float(min((height, width)))
+        max_original_size = float(max((height, width)))
+        if max_original_size / min_original_size * size > max_size:
+            raw_size = max_size * min_original_size / max_original_size
+            size = int(round(raw_size))
+
+    if (height <= width and height == size) or (width <= height and width == size):
+        oh, ow = height, width
+    elif width < height:
+        ow = size
+        if max_size is not None and raw_size is not None:
+            oh = int(raw_size * height / width)
+        else:
+            oh = int(size * height / width)
+    else:
+        oh = size
+        if max_size is not None and raw_size is not None:
+            ow = int(raw_size * width / height)
+        else:
+            ow = int(size * width / height)
+
+    return (oh, ow)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
+def get_resize_output_image_size(
+    input_image: np.ndarray,
+    size: Union[int, Tuple[int, int], List[int]],
+    max_size: Optional[int] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> Tuple[int, int]:
+    """
+    Computes the output image size given the input image size and the desired output size. If the desired output size
+    is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
+    image size is computed by keeping the aspect ratio of the input image size.
+
+    Args:
+        input_image (`np.ndarray`):
+            The image to resize.
+        size (`int` or `Tuple[int, int]` or `List[int]`):
+            The desired output size.
+        max_size (`int`, *optional*):
+            The maximum allowed output size.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+    """
+    image_size = get_image_size(input_image, input_data_format)
+    if isinstance(size, (list, tuple)):
+        return size
+
+    return get_size_with_aspect_ratio(image_size, size, max_size)
+
+
+# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
+def get_image_size_for_max_height_width(
+    input_image: np.ndarray,
+    max_height: int,
+    max_width: int,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> Tuple[int, int]:
+    """
+    Computes the output image size given the input image and the maximum allowed height and width. The aspect ratio is
+    preserved. Importantly, even if `image_height < max_height` and `image_width < max_width`, the image is resized so
+    that at least one of its edges equals `max_height` or `max_width`.
+    For example:
+    - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
+    - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
+    Args:
+        input_image (`np.ndarray`):
+            The image to resize.
+        max_height (`int`):
+            The maximum allowed height.
+        max_width (`int`):
+            The maximum allowed width.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+ """ + image_size = get_image_size(input_image, input_data_format) + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn +def get_numpy_to_framework_fn(arr) -> Callable: + """ + Returns a function that converts a numpy array to the framework of the input array. + + Args: + arr (`np.ndarray`): The array to convert. + """ + if isinstance(arr, np.ndarray): + return np.array + if is_tf_available() and is_tf_tensor(arr): + import tensorflow as tf + + return tf.convert_to_tensor + if is_torch_available() and is_torch_tensor(arr): + import torch + + return torch.tensor + if is_flax_available() and is_jax_tensor(arr): + import jax.numpy as jnp + + return jnp.array + raise ValueError(f"Cannot convert arrays of type {type(arr)}") + + +# Copied from transformers.models.detr.image_processing_detr.safe_squeeze +def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: + """ + Squeezes an array, but only if the axis specified has dim 1. + """ + if axis is None: + return arr.squeeze() + + try: + return arr.squeeze(axis=axis) + except ValueError: + return arr + + +# Copied from transformers.models.detr.image_processing_detr.normalize_annotation +def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width( + images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. 
+ """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by RTDETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all COCO annotations for the given image. + annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # for conversion to coco api + area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32) + iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = np.asarray(keypoints, dtype=np.float32) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + return new_target + + +# Copied from transformers.models.detr.image_processing_detr.resize_annotation +def resize_annotation( + annotation: Dict[str, Any], + orig_size: Tuple[int, int], + target_size: Tuple[int, int], + threshold: float = 0.5, + resample: PILImageResampling = PILImageResampling.NEAREST, +): + """ + Resizes an annotation to a target size. + + Args: + annotation (`Dict[str, Any]`): + The annotation dictionary. + orig_size (`Tuple[int, int]`): + The original size of the input image. + target_size (`Tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`): + The resampling filter to use when resizing the masks. 
+ """ + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size)) + ratio_height, ratio_width = ratios + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = np.array([resize(mask, target_size, resample=resample) for mask in masks]) + masks = masks.astype(np.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + +class RTDetrImageProcessor(BaseImageProcessor): + r""" + Constructs a RT-DETR image processor. + + Args: + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be + overridden by the `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 640, "width": 640}`): + Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean values to use when normalizing the image. Can be a single value or a list of values, one for each + channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one + for each channel. 
Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `False`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = False, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + do_convert_annotations: bool = True, + do_pad: bool = False, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> None: + size = size if size is not None else {"height": 640, "width": 640} + size = get_size_dict(size, default_to_square=False) + + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + + def prepare_annotation( + self, + image: np.ndarray, + target: Dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Dict: + """ + Prepare an annotation for feeding into RTDETR model. 
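+ Only `AnnotationFormat.COCO_DETECTION` is supported; any other format raises a `ValueError`.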
+ """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + new_size = get_resize_output_image_size( + image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format + ) + elif "max_height" in size and "max_width" in size: + new_size = get_image_size_for_max_height_width( + image, size["max_height"], size["max_width"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." 
+ ) + image = resize( + image, + size=new_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PILImageResampling = PILImageResampling.NEAREST, + ) -> Dict: + """ + Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched + to this number. + """ + return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation + def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict: + """ + Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to + `[center_x, center_y, width, height]` format and from absolute to relative pixel values. + """ + return normalize_annotation(annotation, image_size=image_size) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image + def _update_annotation_for_padded_image( + self, + annotation: Dict, + input_image_size: Tuple[int, int], + output_image_size: Tuple[int, int], + padding, + update_bboxes, + ) -> Dict: + """ + Update the annotation for a padded image. 
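+ Masks are zero-padded to the padded size; when `update_bboxes` is `True`, the relative
+ `(center_x, center_y, width, height)` boxes are rescaled by the input/output size ratios so they stay
+ aligned after the bottom-right padding.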
+ """ + new_annotation = {} + new_annotation["size"] = output_image_size + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = pad( + masks, + padding, + mode=PaddingMode.CONSTANT, + constant_values=0, + input_data_format=ChannelDimension.FIRST, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= np.asarray( + [ + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + ] + ) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: Tuple[int, int], + annotation: Optional[Dict[str, Any]] = None, + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes + ) + return padded_image, annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad + def pad( + self, + images: List[np.ndarray], + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + pad_size: Optional[Dict[str, int]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + images (List[`np.ndarray`]): + Images to pad. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + Annotations to transform according to the padding that is applied to the images. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + update_bboxes (`bool`, *optional*, defaults to `True`): + Whether to update the bounding boxes in the annotations to match the padded images. If the + bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)` + format, the bounding boxes will not be updated. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + padded_image, padded_annotation = self._pad_image( + image, + padded_size, + annotation, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=update_bboxes, + ) + padded_images.append(padded_image) + padded_annotations.append(padded_annotation) + + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [ + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) + for image in images + ] + data["pixel_mask"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations + ] + + return encoded_inputs + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample=None, # PILImageResampling + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[Dict[str, int]] = None, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. 
If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `List[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. 
+ return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
+ Type of tensors to return. If `None`, will return the list of images.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ pad_size (`Dict[str, int]`, *optional*):
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ """
+ do_resize = self.do_resize if do_resize is None else do_resize
+ size = self.size if size is None else size
+ size = get_size_dict(size=size, default_to_square=True)
+ resample = self.resample if resample is None else resample
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
+ image_mean = self.image_mean if image_mean is None else image_mean
+ image_std = self.image_std if image_std is None else image_std
+ do_convert_annotations = (
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
+ )
+ do_pad = self.do_pad if do_pad is None else do_pad
+ pad_size = self.pad_size if pad_size is None else pad_size
+ format = self.format if format is None else format
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ if annotations is not None and isinstance(annotations, dict):
+ annotations = [annotations]
+
+ if annotations is not None and len(images) != len(annotations):
+ raise ValueError(
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ # All transformations expect numpy arrays
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ prepared_images = []
+ prepared_annotations = []
+ for image, target in zip(images, annotations):
+ target = self.prepare_annotation(
+ image,
+ target,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=input_data_format,
+ )
+ prepared_images.append(image)
+ prepared_annotations.append(target)
+ images = prepared_images
+ annotations = prepared_annotations
+ del prepared_images, prepared_annotations
+
+ # transformations
+ if do_resize:
+ if annotations is not None:
+ resized_images, resized_annotations = [], []
+ for image, target in zip(images, annotations):
+ orig_size = get_image_size(image, input_data_format)
+ resized_image = self.resize(
+ image, size=size, resample=resample, input_data_format=input_data_format
+ )
+ resized_annotation = self.resize_annotation(
+ target, orig_size, get_image_size(resized_image, input_data_format)
+ )
+ resized_images.append(resized_image)
+ resized_annotations.append(resized_annotation)
+ images = resized_images
+ annotations = resized_annotations
+ del resized_images, resized_annotations
+ else:
+ images = [
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
+
+ if do_normalize:
+ images = [
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
+ ]
+
+ if do_convert_annotations and annotations is not None:
+ annotations = [
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
+ for annotation, image in zip(annotations, images)
+ ]
+
+ if do_pad:
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ encoded_inputs = self.pad(
+ images,
+ annotations=annotations,
+ return_pixel_mask=True,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ update_bboxes=do_convert_annotations,
+ return_tensors=return_tensors,
+ pad_size=pad_size,
+ )
+ else:
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ for image in images
+ ]
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+
+ return encoded_inputs
+
+ def post_process_object_detection(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ target_sizes: Union[TensorType, List[Tuple]] = None,
+ use_focal_loss: bool = True,
+ ):
+ """
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch. 
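+ Boxes are rescaled to absolute coordinates only when `target_sizes` is provided; otherwise they remain
+ relative to the input image size.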
+
+ Args:
+ outputs ([`RTDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions.
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Whether the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
+ to compute the scores of each detection, otherwise a softmax function is used.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ requires_backends(self, ["torch"])
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+ # convert from relative cxcywh to absolute xyxy
+ boxes = center_to_corners_format(out_bbox)
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+ if isinstance(target_sizes, List):
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ num_top_queries = out_logits.shape[1]
+ num_classes = out_logits.shape[2]
+
+ if use_focal_loss:
+ scores = torch.nn.functional.sigmoid(out_logits)
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1)
+ labels = index % num_classes
+ index = index // num_classes
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+ else:
+ scores = torch.nn.functional.softmax(out_logits, dim=-1)[:, :, :-1]
+ scores, labels = scores.max(dim=-1)
+ if scores.shape[1] > num_top_queries:
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
+ labels = torch.gather(labels, dim=1, index=index)
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
+ results = []
+ for score, label, box in zip(scores, labels, boxes):
+ results.append(
+ {
+ "scores": score[score > threshold],
+ "labels": label[score > threshold],
+ "boxes": box[score > threshold],
+ }
+ )
+
+ return results
+
+
+__all__ = ["RTDetrImageProcessor"]
diff --git a/docs/transformers/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/docs/transformers/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9bd83f1b230131caca55c93da96a07b0c485570
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py
@@ -0,0 +1,608 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_rt_detr.py file directly. One of our CI checks enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import pathlib +from typing import Any, Dict, List, Optional, Tuple, Union + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, + add_start_docstrings, + get_image_size_for_max_height_width, + get_max_height_width, + safe_squeeze, +) +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + validate_annotations, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + requires_backends, +) +from ...utils.import_utils import requires +from .image_processing_rt_detr import get_size_with_aspect_ratio + + +if is_torch_available(): + import torch + + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +class RTDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + format: Optional[Union[str, AnnotationFormat]] + do_convert_annotations: Optional[bool] + do_pad: Optional[bool] + pad_size: Optional[Dict[str, int]] + return_segmentation_masks: Optional[bool] + + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + + +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by RT-DETR. + """ + image_height, image_width = image.size()[-2:] + + image_id = target["image_id"] + image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device) + + # Get all COCO annotations for the given image. 
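+ # Crowd instances (iscrowd == 1) are skipped; COCO boxes arrive as [x, y, width, height] and are converted
+ # to [x0, y0, x1, y1] corners (then clipped to the image) below.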
+ annotations = target["annotations"] + classes = [] + area = [] + boxes = [] + keypoints = [] + for obj in annotations: + if "iscrowd" not in obj or obj["iscrowd"] == 0: + classes.append(obj["category_id"]) + area.append(obj["area"]) + boxes.append(obj["bbox"]) + if "keypoints" in obj: + keypoints.append(obj["keypoints"]) + + classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) + area = torch.as_tensor(area, dtype=torch.float32, device=image.device) + iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = { + "image_id": image_id, + "class_labels": classes[keep], + "boxes": boxes[keep], + "area": area[keep], + "iscrowd": iscrowd[keep], + "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), + } + + if keypoints: + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + return new_target + + +@add_start_docstrings( + "Constructs a fast RTDetr image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + return_segmentation_masks (`bool`, *optional*, defaults to `False`): + Whether to return segmentation masks. 
+ """, +) +@requires(backends=("torchvision", "torch")) +class RTDetrImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + format = AnnotationFormat.COCO_DETECTION + do_resize = True + do_rescale = True + do_normalize = False + do_pad = False + size = {"height": 640, "width": 640} + default_to_square = False + model_input_names = ["pixel_values", "pixel_mask"] + valid_kwargs = RTDetrFastImageProcessorKwargs + do_convert_annotations = True + + def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None: + # Backwards compatibility + do_convert_annotations = kwargs.get("do_convert_annotations", None) + do_normalize = kwargs.get("do_normalize", None) + if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None: + self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize + + super().__init__(**kwargs) + + def prepare_annotation( + self, + image: torch.Tensor, + target: Dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Dict: + """ + Prepare an annotation for feeding into RT_DETR model. + """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use if resizing the image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. 
+ new_size = get_size_with_aspect_ratio(
+ image.size()[-2:],
+ size["shortest_edge"],
+ size["longest_edge"],
+ )
+ elif size.max_height and size.max_width:
+ new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"])
+ elif size.height and size.width:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys, 'shortest_edge' and 'longest_edge' keys, or"
+ " 'max_height' and 'max_width' keys. Got"
+ f" {size.keys()}."
+ )
+
+ image = F.resize(
+ image,
+ size=new_size,
+ interpolation=interpolation,
+ **kwargs,
+ )
+ return image
+
+ def resize_annotation(
+ self,
+ annotation: Dict[str, Any],
+ orig_size: Tuple[int, int],
+ target_size: Tuple[int, int],
+ threshold: float = 0.5,
+ interpolation: "F.InterpolationMode" = None,
+ ):
+ """
+ Resizes an annotation to a target size.
+
+ Args:
+ annotation (`Dict[str, Any]`):
+ The annotation dictionary.
+ orig_size (`Tuple[int, int]`):
+ The original size of the input image.
+ target_size (`Tuple[int, int]`):
+ The target size of the image, as returned by the preprocessing `resize` step.
+ threshold (`float`, *optional*, defaults to 0.5):
+ The threshold used to binarize the segmentation masks.
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.NEAREST`):
+ The resampling filter to use when resizing the masks.
+ """
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.NEAREST
+ ratio_height, ratio_width = [target / orig for target, orig in zip(target_size, orig_size)]
+
+ new_annotation = {}
+ new_annotation["size"] = target_size
+
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ scaled_boxes = boxes * torch.as_tensor(
+ [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, device=boxes.device
+ )
+ new_annotation["boxes"] = scaled_boxes
+ elif key == "area":
+ area = value
+ scaled_area = area * (ratio_width * ratio_height)
+ new_annotation["area"] = scaled_area
+ elif key == "masks":
+ masks = value[:, None]
+ masks = [F.resize(mask, target_size, interpolation=interpolation) for mask in masks]
+ masks = torch.stack(masks).to(torch.float32)
+ masks = masks[:, 0] > threshold
+ new_annotation["masks"] = masks
+ elif key == "size":
+ new_annotation["size"] = target_size
+ else:
+ new_annotation[key] = value
+
+ return new_annotation
+
+ def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
+ image_height, image_width = image_size
+ norm_annotation = {}
+ for key, value in annotation.items():
+ if key == "boxes":
+ boxes = value
+ boxes = corners_to_center_format(boxes)
+ boxes /= torch.as_tensor(
+ [image_width, image_height, image_width, image_height], dtype=torch.float32, device=boxes.device
+ )
+ norm_annotation[key] = boxes
+ else:
+ norm_annotation[key] = value
+ return norm_annotation
+
+ def _update_annotation_for_padded_image(
+ self,
+ annotation: Dict,
+ input_image_size: Tuple[int, int],
+ output_image_size: Tuple[int, int],
+ padding,
+ update_bboxes,
+ ) -> Dict:
+ """
+ Update the annotation for a padded image. 
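+ When `update_bboxes` is `True`, relative boxes are rescaled by the input/output size ratios (note that the
+ generator below yields `input / output`; its loop variables are bound in `(output, input)` order).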
+ """ + new_annotation = {} + new_annotation["size"] = output_image_size + ratio_height, ratio_width = (input / output for output, input in zip(output_image_size, input_image_size)) + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = F.pad( + masks, + padding, + fill=0, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height], device=boxes.device) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + def pad( + self, + image: torch.Tensor, + padded_size: Tuple[int, int], + annotation: Optional[Dict[str, Any]] = None, + update_bboxes: bool = True, + fill: int = 0, + ): + original_size = image.size()[-2:] + padding_bottom = padded_size[0] - original_size[0] + padding_right = padded_size[1] - original_size[1] + if padding_bottom < 0 or padding_right < 0: + raise ValueError( + f"Padding dimensions are negative. Please make sure that the padded size is larger than the " + f"original size. Got padded size: {padded_size}, original size: {original_size}." + ) + if original_size != padded_size: + padding = [0, 0, padding_right, padding_bottom] + image = F.pad(image, padding, fill=fill) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, original_size, padded_size, padding, update_bboxes + ) + + # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device) + pixel_mask[: original_size[0], : original_size[1]] = 1 + + return image, pixel_mask, annotation + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. 
+ If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + return_segmentation_masks (`bool`, *optional*, defaults to `False`): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + """, + ) + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + **kwargs: Unpack[RTDetrFastImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, annotations=annotations, masks_path=masks_path, **kwargs) + + def _preprocess( + self, + images: List["torch.Tensor"], + annotations: Optional[Union[AnnotationType, List[AnnotationType]]], + return_segmentation_masks: bool, + masks_path: Optional[Union[str, pathlib.Path]], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_convert_annotations: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + do_pad: bool, + pad_size: Optional[Dict[str, int]], + format: Optional[Union[str, AnnotationFormat]], + return_tensors: Optional[Union[str, TensorType]], + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + """ + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." 
+ )
+
+ format = AnnotationFormat(format)
+ if annotations is not None:
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
+
+ data = {}
+ processed_images = []
+ processed_annotations = []
+ pixel_masks = [] # only populated when do_pad is True
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ annotation = self.prepare_annotation(
+ image,
+ annotation,
+ format,
+ return_segmentation_masks=return_segmentation_masks,
+ masks_path=masks_path,
+ input_data_format=ChannelDimension.FIRST,
+ )
+
+ if do_resize:
+ resized_image = self.resize(image, size=size, interpolation=interpolation)
+ if annotations is not None:
+ annotation = self.resize_annotation(
+ annotation,
+ orig_size=image.size()[-2:],
+ target_size=resized_image.size()[-2:],
+ )
+ image = resized_image
+ # Fused rescale and normalize
+ image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
+ if do_convert_annotations and annotations is not None:
+ annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
+
+ processed_images.append(image)
+ processed_annotations.append(annotation)
+ images = processed_images
+ annotations = processed_annotations if annotations is not None else None
+
+ if do_pad:
+ # depends on all resized image shapes so we need another loop
+ if pad_size is not None:
+ padded_size = (pad_size["height"], pad_size["width"])
+ else:
+ padded_size = get_max_height_width(images)
+
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)):
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
+ if padded_size == image.size()[-2:]:
+ padded_images.append(image)
+ pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device))
+ padded_annotations.append(annotation)
+ continue
+ image, pixel_mask, annotation = self.pad(
+ image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations
+ )
+ padded_images.append(image)
+ padded_annotations.append(annotation)
+ pixel_masks.append(pixel_mask)
+ images = padded_images
+ annotations = padded_annotations if annotations is not None else None
+ data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)})
+
+ data.update({"pixel_values": torch.stack(images, dim=0)})
+ encoded_inputs = BatchFeature(data, tensor_type=return_tensors)
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
+ ]
+ return encoded_inputs
+
+ def post_process_object_detection(
+ self,
+ outputs,
+ threshold: float = 0.5,
+ target_sizes: Union[TensorType, List[Tuple]] = None,
+ use_focal_loss: bool = True,
+ ):
+ """
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+ Args:
+ outputs ([`RTDetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ threshold (`float`, *optional*, defaults to 0.5):
+ Score threshold to keep object detection predictions. 
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
+ Whether the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
+ to compute the scores of each detection, otherwise a softmax function is used.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ requires_backends(self, ["torch"])
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+ # convert from relative cxcywh to absolute xyxy
+ boxes = center_to_corners_format(out_bbox)
+ if target_sizes is not None:
+ if len(out_logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+ if isinstance(target_sizes, List):
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
+ else:
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+ boxes = boxes * scale_fct[:, None, :]
+
+ num_top_queries = out_logits.shape[1]
+ num_classes = out_logits.shape[2]
+
+ if use_focal_loss:
+ scores = torch.nn.functional.sigmoid(out_logits)
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1)
+ labels = index % num_classes
+ index = index // num_classes
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+ else:
+ scores = torch.nn.functional.softmax(out_logits, dim=-1)[:, :, :-1]
+ scores, labels = scores.max(dim=-1)
+ if scores.shape[1] > num_top_queries:
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
+ labels = torch.gather(labels, dim=1, index=index)
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
+ results = []
+ for score, label, box in zip(scores, labels, boxes):
+ results.append(
+ {
+ "scores": score[score > threshold],
+ "labels": label[score > threshold],
+ "boxes": box[score > threshold],
+ }
+ )
+
+ return results
+
+
+__all__ = ["RTDetrImageProcessorFast"]
diff --git a/docs/transformers/src/transformers/models/rt_detr/modeling_rt_detr.py b/docs/transformers/src/transformers/models/rt_detr/modeling_rt_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ecf72d0a12cea00a552443316e9f4b48d310590
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -0,0 +1,2081 @@
+# coding=utf-8
+# Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+"""PyTorch RT-DETR model.""" + +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2CLS, ACT2FN +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...integrations import use_kernel_forward_from_hub +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from ...utils.backbone_utils import load_backbone +from .configuration_rt_detr import RTDetrConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "RTDetrConfig" +# TODO: Replace all occurrences of the checkpoint with the final one +_CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd" + + +@use_kernel_forward_from_hub("MultiScaleDeformableAttention") +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention +class MultiScaleDeformableAttention(nn.Module): + def forward( + self, + value: Tensor, + value_spatial_shapes: Tensor, + value_spatial_shapes_list: List[Tuple], + level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ): + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes_list): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +@dataclass +class RTDetrDecoderOutput(ModelOutput): + """ + Base class for outputs of the RTDetrDecoder. 
This class adds two attributes to
+    BaseModelOutputWithCrossAttentions, namely:
+    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
+    - a stacked tensor of intermediate reference points.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+            Stacked intermediate hidden states (output of each layer of the decoder).
+        intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+            Stacked intermediate logits (logits of each layer of the decoder).
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, 4)`):
+            Stacked intermediate reference points (reference points of each layer of the decoder).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+            plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+            the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+            used to compute the weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_logits: Optional[torch.FloatTensor] = None
+    intermediate_reference_points: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class RTDetrModelOutput(ModelOutput):
+    """
+    Base class for outputs of the RT-DETR encoder-decoder model.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+            Stacked intermediate hidden states (output of each layer of the decoder).
+        intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
+            Stacked intermediate logits (logits of each layer of the decoder).
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). 
+        enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Logits of predicted bounding boxes coordinates in the first stage.
+        denoising_meta_values (`dict`):
+            Extra dictionary for the denoising-related values.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_logits: Optional[torch.FloatTensor] = None
+    intermediate_reference_points: Optional[torch.FloatTensor] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    init_reference_points: Optional[torch.FloatTensor] = None
+    enc_topk_logits: Optional[torch.FloatTensor] = None
+    enc_topk_bboxes: Optional[torch.FloatTensor] = None
+    enc_outputs_class: Optional[torch.FloatTensor] = None
+    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+    denoising_meta_values: Optional[Dict] = None
+
+
+@dataclass
+class RTDetrObjectDetectionOutput(ModelOutput):
+    """
+    Output type of [`RTDetrForObjectDetection`].
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+            scale-invariant IoU loss.
+        loss_dict (`Dict`, *optional*):
+            A dictionary containing the individual losses. Useful for logging.
+        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+            Classification logits (including no-object) for all queries.
+        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+            possible padding). You can use [`~RTDetrImageProcessor.post_process_object_detection`] to retrieve the
+            unnormalized (absolute) bounding boxes.
+        auxiliary_outputs (`list[Dict]`, *optional*):
+            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+            `pred_boxes`) for each decoder layer.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+            Stacked intermediate hidden states (output of each layer of the decoder).
+        intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
+            Stacked intermediate logits (logits of each layer of the decoder).
+        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+            Stacked intermediate reference points (reference points of each layer of the decoder).
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer
+            plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries,
+            num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted
+            average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
+            layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+            Initial reference points sent through the Transformer decoder.
+        enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+            picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
+            foreground and background).
+        enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Logits of predicted bounding boxes coordinates in the encoder stage.
+        enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+            picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+            foreground and background).
+        enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Logits of predicted bounding boxes coordinates in the first stage.
+        denoising_meta_values (`dict`):
+            Extra dictionary for the denoising-related values.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    loss_dict: Optional[Dict] = None
+    logits: Optional[torch.FloatTensor] = None
+    pred_boxes: Optional[torch.FloatTensor] = None
+    auxiliary_outputs: Optional[List[Dict]] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_logits: Optional[torch.FloatTensor] = None
+    intermediate_reference_points: Optional[torch.FloatTensor] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    init_reference_points: Optional[Tuple[torch.FloatTensor]] = None
+    enc_topk_logits: Optional[torch.FloatTensor] = None
+    enc_topk_bboxes: Optional[torch.FloatTensor] = None
+    enc_outputs_class: Optional[torch.FloatTensor] = None
+    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+    denoising_meta_values: Optional[Dict] = None
+
+
+def _get_clones(partial_module, N):
+    return nn.ModuleList([partial_module() for _ in range(N)])
+
+
+# Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
+class RTDetrFrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
+    torchvision.models.resnet[18,34,50,101] produce nans.
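+
+    Example (an illustrative sketch, not from the original file; the internal import path is an
+    assumption and may change). With frozen statistics, the module matches inference-mode batch norm:
+
+    ```python
+    >>> import torch
+    >>> from transformers.models.rt_detr.modeling_rt_detr import RTDetrFrozenBatchNorm2d
+
+    >>> frozen_bn = RTDetrFrozenBatchNorm2d(8)
+    >>> x = torch.randn(2, 8, 4, 4)
+    >>> reference = torch.nn.functional.batch_norm(
+    ...     x, frozen_bn.running_mean, frozen_bn.running_var, frozen_bn.weight, frozen_bn.bias, training=False
+    ... )
+    >>> torch.allclose(frozen_bn(x), reference, atol=1e-5)
+    True
+    ```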
+ """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = RTDetrFrozenBatchNorm2d(module.num_features) + + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`List[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. 
+ """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries]) + # positive and negative mask + negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + denoise_positive_idx = torch.split( + denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths] + ) + # total denoising queries + num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = center_to_corners_format(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = corners_to_center_format(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + target_size = num_denoising_queries + num_queries + attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = True + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True + 
+ denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": [num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +class RTDetrConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_rt_detr_resnet.py. + + nn.BatchNorm2d layers are replaced by RTDetrFrozenBatchNorm2d as defined above. + https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class RTDetrConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrEncoderLayer(nn.Module): + def __init__(self, config: RTDetrConfig): + super().__init__() + self.normalize_before = config.normalize_before + + # self-attention + self.self_attn = RTDetrMultiheadAttention( + embed_dim=config.encoder_hidden_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.encoder_activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class RTDetrRepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: RTDetrConfig): + super().__init__() + + activation = config.activation_function + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class RTDetrCSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. 
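+
+    Example (an illustrative sketch assuming a default `RTDetrConfig`; the class is internal, so
+    the import path is an assumption):
+
+    ```python
+    >>> import torch
+    >>> from transformers import RTDetrConfig
+    >>> from transformers.models.rt_detr.modeling_rt_detr import RTDetrCSPRepLayer
+
+    >>> config = RTDetrConfig()
+    >>> layer = RTDetrCSPRepLayer(config)
+    >>> hidden_state = torch.randn(1, config.encoder_hidden_dim * 2, 32, 32)  # channels halve on the way out
+    >>> layer(hidden_state).shape
+    torch.Size([1, 256, 32, 32])
+    ```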
+ """ + + def __init__(self, config: RTDetrConfig): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.activation_function + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[RTDetrRepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = RTDetrConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1) + hidden_state_2 = self.conv2(hidden_state) + return self.conv3(hidden_state_1 + hidden_state_2) + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr +class RTDetrMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int): + super().__init__() + + self.attn = MultiScaleDeformableAttention() + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RTDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." 
+ ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + total_elements = sum(height * width for height, width in spatial_shapes_list) + if total_elements != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + num_coordinates = reference_points.shape[-1] + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + output = self.attn( + value, + spatial_shapes, + spatial_shapes_list, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + + output = self.output_proj(output) + + return output, attention_weights + + +class RTDetrMultiheadAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). 
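+
+    Example (an illustrative sketch with toy shapes; the class is internal, so the import path is
+    an assumption):
+
+    ```python
+    >>> import torch
+    >>> from transformers.models.rt_detr.modeling_rt_detr import RTDetrMultiheadAttention
+
+    >>> attn = RTDetrMultiheadAttention(embed_dim=256, num_heads=8)
+    >>> hidden_states = torch.randn(2, 100, 256)
+    >>> position_embeddings = torch.randn(2, 100, 256)  # added to queries and keys, not to values
+    >>> output, weights = attn(hidden_states, position_embeddings=position_embeddings, output_attentions=True)
+    >>> output.shape, weights.shape
+    (torch.Size([2, 100, 256]), torch.Size([2, 8, 100, 100]))
+    ```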
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # get queries, keys and values + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + # expand attention_mask + if attention_mask is not None: + # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size()) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+            # In order to do so, attn_weights has to be reshaped
+            # twice and reused in the following computation
+            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+class RTDetrDecoderLayer(nn.Module):
+    def __init__(self, config: RTDetrConfig):
+        super().__init__()
+        # self-attention
+        self.self_attn = RTDetrMultiheadAttention(
+            embed_dim=config.d_model,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.decoder_activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+        # cross-attention
+        self.encoder_attn = RTDetrMultiscaleDeformableAttention(
+            config,
+            num_heads=config.decoder_attention_heads,
+            n_points=config.decoder_n_points,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+        # feedforward neural networks
+        self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
+        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[torch.Tensor] = None,
+        reference_points=None,
+        spatial_shapes=None,
+        spatial_shapes_list=None,
+        level_start_index=None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input to the layer of shape `(batch, seq_len, embed_dim)`.
+            position_embeddings (`torch.FloatTensor`, *optional*):
+                Position embeddings that are added to the queries and keys in the self-attention layer.
+            reference_points (`torch.FloatTensor`, *optional*):
+                Reference points.
+            spatial_shapes (`torch.LongTensor`, *optional*):
+                Spatial shapes.
+            spatial_shapes_list (`List[Tuple[int, int]]`, *optional*):
+                Spatial shapes of the feature maps as a list of tuples.
+            level_start_index (`torch.LongTensor`, *optional*):
+                Level start index.
+            encoder_hidden_states (`torch.FloatTensor`):
+                Cross-attention input to the layer of shape `(batch, seq_len, embed_dim)`.
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
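+
+        Example (an illustrative sketch with toy shapes on CPU; the class is internal, so the
+        import path is an assumption):
+
+        ```python
+        >>> import torch
+        >>> from transformers import RTDetrConfig
+        >>> from transformers.models.rt_detr.modeling_rt_detr import RTDetrDecoderLayer
+
+        >>> config = RTDetrConfig()  # d_model=256, num_feature_levels=3 by default
+        >>> layer = RTDetrDecoderLayer(config)
+        >>> queries = torch.randn(1, 300, config.d_model)
+        >>> spatial_shapes_list = [(32, 32), (16, 16), (8, 8)]
+        >>> memory = torch.randn(1, sum(h * w for h, w in spatial_shapes_list), config.d_model)
+        >>> reference_points = torch.rand(1, 300, 3, 4)  # one normalized (cx, cy, w, h) box per level
+        >>> outputs = layer(
+        ...     queries,
+        ...     reference_points=reference_points,
+        ...     spatial_shapes=torch.tensor(spatial_shapes_list),
+        ...     spatial_shapes_list=spatial_shapes_list,
+        ...     encoder_hidden_states=memory,
+        ... )
+        >>> outputs[0].shape
+        torch.Size([1, 300, 256])
+        ```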
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class RTDetrPreTrainedModel(PreTrainedModel): + config_class = RTDetrConfig + base_model_prefix = "rt_detr" + main_input_name = "pixel_values" + _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"] + + def _init_weights(self, module): + """Initalize the weights""" + if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(layer.weight) + nn.init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + nn.init.constant_(layer.layers[-1].weight, 0) + nn.init.constant_(layer.layers[-1].bias, 0) + + elif isinstance(module, RTDetrMultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + + 
+        elif isinstance(module, RTDetrModel):
+            prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
+            bias = float(-math.log((1 - prior_prob) / prior_prob))
+            nn.init.xavier_uniform_(module.enc_score_head.weight)
+            nn.init.constant_(module.enc_score_head.bias, bias)
+
+        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
+
+        if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
+            nn.init.xavier_uniform_(module.weight_embedding.weight)
+        if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0:
+            nn.init.xavier_uniform_(module.denoising_class_embed.weight)
+
+
+RTDETR_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`RTDetrConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+RTDETR_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`RTDetrImageProcessor.__call__`] for details.
+        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
+
+            - 1 for pixels that are real (i.e. **not masked**),
+            - 0 for pixels that are padding (i.e. **masked**).
+
+            [What are attention masks?](../glossary#attention-mask)
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+            can choose to directly pass a flattened representation of an image.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+            embedded representation.
+        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
+            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class RTDetrEncoder(nn.Module):
+    def __init__(self, config: RTDetrConfig):
+        super().__init__()
+
+        self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+    def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> Tuple[torch.Tensor, ...]:
+        hidden_states = src
+        for layer in self.layers:
+            # each layer returns a tuple; feed only the hidden states into the next layer
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=src_mask,
+                position_embeddings=pos_embed,
+                output_attentions=output_attentions,
+            )
+            hidden_states = layer_outputs[0]
+        # return the last layer's full tuple (hidden states and, optionally, attention weights)
+        return layer_outputs
+
+
+class RTDetrHybridEncoder(nn.Module):
+    """
+    Encoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network
+    (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://arxiv.org/abs/2304.08069
+
+    Args:
+        config: RTDetrConfig
+    """
+
+    def __init__(self, config: RTDetrConfig):
+        super().__init__()
+        self.config = config
+        self.in_channels = config.encoder_in_channels
+        self.feat_strides = config.feat_strides
+        self.encoder_hidden_dim = config.encoder_hidden_dim
+        self.encode_proj_layers = config.encode_proj_layers
+        self.positional_encoding_temperature = config.positional_encoding_temperature
+        self.eval_size = config.eval_size
+        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
+        self.out_strides = self.feat_strides
+        self.num_fpn_stages = len(self.in_channels) - 1
+        self.num_pan_stages = len(self.in_channels) - 1
+        activation = config.activation_function
+
+        # encoder transformer
+        self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))])
+
+        # top-down FPN
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_blocks = nn.ModuleList()
+        for _ in range(self.num_fpn_stages):
+            lateral_conv = RTDetrConvNormLayer(
+                config,
+                in_channels=self.encoder_hidden_dim,
+                out_channels=self.encoder_hidden_dim,
+                kernel_size=1,
+                stride=1,
+                activation=activation,
+            )
+            fpn_block = RTDetrCSPRepLayer(config)
+            self.lateral_convs.append(lateral_conv)
+            self.fpn_blocks.append(fpn_block)
+
+        # bottom-up PAN
+        self.downsample_convs = nn.ModuleList()
+        self.pan_blocks = nn.ModuleList()
+        for _ in range(self.num_pan_stages):
+            downsample_conv = RTDetrConvNormLayer(
+                config,
+                in_channels=self.encoder_hidden_dim,
+                out_channels=self.encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+                activation=activation,
+            )
+            pan_block = RTDetrCSPRepLayer(config)
+            self.downsample_convs.append(downsample_conv)
+            self.pan_blocks.append(pan_block)
+
+    @staticmethod
+    def build_2d_sincos_position_embedding(
+        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+    ):
+        grid_w = torch.arange(torch_int(width), device=device).to(dtype)
+        grid_h = torch.arange(torch_int(height), device=device).to(dtype)
+        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+        if embed_dim % 4 != 0:
+            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
+        omega = 1.0 / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None]
+        out_h = grid_h.flatten()[..., None] @ omega[None]
+
+        return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
+
+    def forward(
+        self,
+        inputs_embeds=None,
+        attention_mask=None,
+        position_embeddings=None,
+        spatial_shapes=None,
+        level_start_index=None,
+        valid_ratios=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
+                - 1 for pixel features that are real (i.e. **not masked**),
+                - 0 for pixel features that are padding (i.e. **masked**).
+                [What are attention masks?](../glossary#attention-mask)
+            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Position embeddings that are added to the queries and keys in each self-attention layer.
+            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
+                Spatial shapes of each feature map.
+            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
+                Starting index of each feature map.
+            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+                Ratio of valid area in each feature level.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
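+
+        Example (an illustrative sketch with toy feature maps already projected to
+        `encoder_hidden_dim` channels; the class is internal, so the import path is an assumption):
+
+        ```python
+        >>> import torch
+        >>> from transformers import RTDetrConfig
+        >>> from transformers.models.rt_detr.modeling_rt_detr import RTDetrHybridEncoder
+
+        >>> config = RTDetrConfig()
+        >>> encoder = RTDetrHybridEncoder(config)
+        >>> feature_maps = [torch.randn(1, config.encoder_hidden_dim, size, size) for size in (64, 32, 16)]
+        >>> outputs = encoder(inputs_embeds=feature_maps, return_dict=True)
+        >>> [tuple(feature.shape) for feature in outputs.last_hidden_state]
+        [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16)]
+        ```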
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # encoder + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + height, width = hidden_states[enc_ind].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + device=src_flatten.device, + dtype=src_flatten.dtype, + ) + else: + pos_embed = None + + layer_outputs = self.encoder[i]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + hidden_states[enc_ind] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + + # top-down FPN + fpn_feature_maps = [hidden_states[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + # apply lateral block + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + # apply fpn block + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps = fpn_feature_maps[::-1] + + # bottom-up PAN + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + if not return_dict: + return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions + ) + + +class RTDetrDecoder(RTDetrPreTrainedModel): + def __init__(self, config: RTDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = 
None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ position_embeddings=None,
+ reference_points=None,
+ spatial_shapes=None,
+ spatial_shapes_list=None,
+ level_start_index=None,
+ valid_ratios=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+ The query embeddings that are passed into the decoder.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
+ in `[0, 1]`:
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
+ reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+ Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
+ spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
+ Spatial shapes of the feature maps.
+ spatial_shapes_list (`List[Tuple[int, int]]`, *optional*):
+ Spatial shapes of the feature maps as a list of tuples.
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
+ Ratio of valid area in each feature level.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
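+
+ At each decoder layer, the reference boxes are refined in inverse-sigmoid (logit) space, i.e.
+ `boxes = sigmoid(bbox_head(hidden_states) + inverse_sigmoid(boxes))` when `bbox_embed` is set, and the
+ per-layer boxes are collected in `intermediate_reference_points`.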
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + + reference_points = F.sigmoid(reference_points) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252 + for idx, decoder_layer in enumerate(self.layers): + reference_points_input = reference_points.unsqueeze(2) + position_embeddings = self.query_pos_head(reference_points) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[idx](hidden_states) + new_reference_points = F.sigmoid(tmp + inverse_sigmoid(reference_points)) + reference_points = new_reference_points.detach() + + intermediate += (hidden_states,) + intermediate_reference_points += ( + (new_reference_points,) if self.bbox_embed is not None else (reference_points,) + ) + + if self.class_embed is not None: + logits = self.class_embed[idx](hidden_states) + intermediate_logits += (logits,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + if self.class_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return RTDetrDecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class RTDetrMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. 
+ + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453 + + """ + + def __init__(self, config, input_dim, d_model, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [d_model] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. + """, + RTDETR_START_DOCSTRING, +) +class RTDetrModel(RTDetrPreTrainedModel): + def __init__(self, config: RTDetrConfig): + super().__init__(config) + + # Create backbone + self.backbone = RTDetrConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + + # Create encoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212 + num_backbone_outs = len(intermediate_channel_sizes) + encoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[_] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + + # Create encoder + self.encoder = RTDetrHybridEncoder(config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + + # Create decoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412 + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = config.decoder_in_channels[_] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list) + + # decoder + self.decoder = RTDetrDecoder(config) + + self.post_init() + + def 
get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_method_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, device=device).to(dtype), + torch.arange(end=width, device=device).to(dtype), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @add_start_docstrings_to_model_forward(RTDETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RTDetrModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], RTDetrModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, RTDetrModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + >>> model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + features = self.backbone(pixel_values, pixel_mask) + + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if encoder_outputs is None: + 
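# Run the hybrid encoder (transformer self-attention on the selected feature levels
+ # followed by FPN/PAN cross-scale fusion) over the projected backbone feature maps.
+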
encoder_outputs = self.encoder(
+ proj_feats,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if output_hidden_states else None,
+ attentions=encoder_outputs[2]
+ if len(encoder_outputs) > 2
+ else encoder_outputs[1]
+ if output_attentions
+ else None,
+ )
+
+ # Equivalent to def _get_encoder_input
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
+ sources = []
+ for level, source in enumerate(encoder_outputs[0]):
+ sources.append(self.decoder_input_proj[level](source))
+
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
+ if self.config.num_feature_levels > len(sources):
+ _len_sources = len(sources)
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0][-1]))
+ for i in range(_len_sources + 1, self.config.num_feature_levels):
+ sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+
+ # Prepare encoder inputs (by flattening)
+ source_flatten = []
+ spatial_shapes_list = []
+ spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
+ for level, source in enumerate(sources):
+ height, width = source.shape[-2:]
+ spatial_shapes[level, 0] = height
+ spatial_shapes[level, 1] = width
+ spatial_shapes_list.append((height, width))
+ source = source.flatten(2).transpose(1, 2)
+ source_flatten.append(source)
+ source_flatten = torch.cat(source_flatten, 1)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+
+ # prepare denoising training
+ if self.training and self.config.num_denoising > 0 and labels is not None:
+ (
+ denoising_class,
+ denoising_bbox_unact,
+ attention_mask,
+ denoising_meta_values,
+ ) = get_contrastive_denoising_training_group(
+ targets=labels,
+ num_classes=self.config.num_labels,
+ num_queries=self.config.num_queries,
+ class_embed=self.denoising_class_embed,
+ num_denoising_queries=self.config.num_denoising,
+ label_noise_ratio=self.config.label_noise_ratio,
+ box_noise_scale=self.config.box_noise_scale,
+ )
+ else:
+ denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
+
+ batch_size = len(source_flatten)
+ device = source_flatten.device
+ dtype = source_flatten.dtype
+
+ # prepare input for decoder
+ if self.training or self.config.anchor_image_size is None:
+ # Pass spatial_shapes as tuple to make it hashable and make sure
+ # lru_cache is working for generate_anchors()
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
+ anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
+ else:
+ anchors, valid_mask = self.anchors, self.valid_mask
+ anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
+
+ # use the valid_mask to selectively retain values in the feature map where the mask is `True`
+ memory = valid_mask.to(source_flatten.dtype) * source_flatten
+
+ output_memory = self.enc_output(memory)
+
+ enc_outputs_class = self.enc_score_head(output_memory)
+ enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
+
+ _, topk_ind =
torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits] + if value is not None + ) + dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values]) + tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs + + return tuple_outputs + + return RTDetrModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further + decoded into scores and classes. 
+ """, + RTDETR_START_DOCSTRING, +) +class RTDetrForObjectDetection(RTDetrPreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = ["bbox_embed", "class_embed"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: RTDetrConfig): + super().__init__(config) + + # RTDETR encoder-decoder model + self.model = RTDetrModel(config) + + # Detection heads on top + self.class_embed = partial(nn.Linear, config.d_model, config.num_labels) + self.bbox_embed = partial(RTDetrMLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = config.decoder_layers + if config.with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + else: + self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)]) + + # hack implementation for iterative bounding box refinement + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @add_start_docstrings_to_model_forward(RTDETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RTDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **loss_kwargs, + ) -> Union[Tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]: + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
+ + Returns: + + Examples: + + ```python + >>> from transformers import RTDetrImageProcessor, RTDetrForObjectDetection + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + >>> model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21] + Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5] + Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22] + Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48] + Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + denoising_meta_values = ( + outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None + ) + + outputs_class = outputs.intermediate_logits if return_dict else outputs[2] + outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3] + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5] + enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4] + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + **loss_kwargs, + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs + else: + output = (logits, pred_boxes) + outputs + return ((loss, 
loss_dict) + output) if loss is not None else output
+
+ return RTDetrObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
+ intermediate_logits=outputs.intermediate_logits,
+ intermediate_reference_points=outputs.intermediate_reference_points,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ init_reference_points=outputs.init_reference_points,
+ enc_topk_logits=outputs.enc_topk_logits,
+ enc_topk_bboxes=outputs.enc_topk_bboxes,
+ enc_outputs_class=outputs.enc_outputs_class,
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
+ denoising_meta_values=outputs.denoising_meta_values,
+ )
+
+
+__all__ = [
+ "RTDetrForObjectDetection",
+ "RTDetrModel",
+ "RTDetrPreTrainedModel",
+]
diff --git a/docs/transformers/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py b/docs/transformers/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ef9a0ac4db8b83fc695c61ea4a5612d94d057c
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py
@@ -0,0 +1,440 @@
+# coding=utf-8
+# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch RT-DETR specific ResNet model. The main difference from the Hugging Face ResNet model is that this
+RTDetrResNet model forces the use of a shortcut at the first layer of the ResNet-18/34 models.
+See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
+""" + +import math +from typing import Optional + +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_rt_detr_resnet import RTDetrResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "RTDetrResNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer +class RTDetrResNetConvLayer(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu" + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else nn.Identity() + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a deep aggressive convolution. + """ + + def __init__(self, config: RTDetrResNetConfig): + super().__init__() + self.embedder = nn.Sequential( + *[ + RTDetrResNetConvLayer( + config.num_channels, + config.embedding_size // 2, + kernel_size=3, + stride=2, + activation=config.hidden_act, + ), + RTDetrResNetConvLayer( + config.embedding_size // 2, + config.embedding_size // 2, + kernel_size=3, + stride=1, + activation=config.hidden_act, + ), + RTDetrResNetConvLayer( + config.embedding_size // 2, + config.embedding_size, + kernel_size=3, + stride=1, + activation=config.hidden_act, + ), + ] + ) + self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.num_channels = config.num_channels + + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values) + embedding = self.pooler(embedding) + return embedding + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetChortCut +class RTDetrResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.normalization = nn.BatchNorm2d(out_channels) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class RTDetrResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. 
+ See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
+ """
+
+ def __init__(
+ self,
+ config: RTDetrResNetConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ should_apply_shortcut: bool = False,
+ ):
+ super().__init__()
+ if in_channels != out_channels:
+ self.shortcut = (
+ nn.Sequential(
+ *[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
+ )
+ if should_apply_shortcut
+ else nn.Identity()
+ )
+ else:
+ self.shortcut = (
+ RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
+ if should_apply_shortcut
+ else nn.Identity()
+ )
+ self.layer = nn.Sequential(
+ RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
+ RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
+ )
+ self.activation = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ residual = hidden_state
+ hidden_state = self.layer(hidden_state)
+ residual = self.shortcut(residual)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RTDetrResNetBottleNeckLayer(nn.Module):
+ """
+ A classic RTDetrResNet's bottleneck layer composed of three convolutions: a `1x1`, a `3x3` and a `1x1`.
+
+ The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
+ convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
+ `downsample_in_bottleneck` is `True`, downsampling is performed in the first layer instead of the second layer.
+ """
+
+ def __init__(
+ self,
+ config: RTDetrResNetConfig,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ ):
+ super().__init__()
+ reduction = 4
+ should_apply_shortcut = in_channels != out_channels or stride != 1
+ reduces_channels = out_channels // reduction
+ if stride == 2:
+ self.shortcut = nn.Sequential(
+ *[
+ nn.AvgPool2d(2, 2, 0, ceil_mode=True),
+ RTDetrResNetShortCut(in_channels, out_channels, stride=1)
+ if should_apply_shortcut
+ else nn.Identity(),
+ ]
+ )
+ else:
+ self.shortcut = (
+ RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
+ if should_apply_shortcut
+ else nn.Identity()
+ )
+ self.layer = nn.Sequential(
+ RTDetrResNetConvLayer(
+ in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
+ ),
+ RTDetrResNetConvLayer(
+ reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
+ ),
+ RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
+ )
+ self.activation = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ residual = hidden_state
+ hidden_state = self.layer(hidden_state)
+ residual = self.shortcut(residual)
+ hidden_state += residual
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class RTDetrResNetStage(nn.Module):
+ """
+ An RTDetrResNet stage composed of stacked layers.
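+ Only the first layer of a stage downsamples (with the given `stride`) and changes the channel count; the
+ remaining `depth - 1` layers keep the shape.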
+ """ + + def __init__( + self, + config: RTDetrResNetConfig, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + ): + super().__init__() + + layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer + + if config.layer_type == "bottleneck": + first_layer = layer( + config, + in_channels, + out_channels, + stride=stride, + ) + else: + first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True) + self.layers = nn.Sequential( + first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)] + ) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet +class RTDetrResNetEncoder(nn.Module): + def __init__(self, config: RTDetrResNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages.append( + RTDetrResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth)) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet +class RTDetrResNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RTDetrResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + _no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`. + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +RTDETR_RESNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RTDetrResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RTDETR_RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`RTDetrImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + ResNet backbone, to be used with frameworks like RTDETR. + """, + RTDETR_RESNET_START_DOCSTRING, +) +class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embedding_size] + config.hidden_sizes + self.embedder = RTDetrResNetEmbeddings(config) + self.encoder = RTDetrResNetEncoder(config) + + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RTDETR_RESNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone + >>> import torch + + >>> config = RTDetrResNetConfig() + >>> model = RTDetrResNetBackbone(config) + + >>> pixel_values = torch.randn(1, 3, 224, 224) + + >>> with torch.no_grad(): + ... 
outputs = model(pixel_values) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output = self.embedder(pixel_values) + + outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = [ + "RTDetrResNetBackbone", + "RTDetrResNetPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/rt_detr/modular_rt_detr.py b/docs/transformers/src/transformers/models/rt_detr/modular_rt_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9913e97e945cc67b881acba0d5054e5ce60dae --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr/modular_rt_detr.py @@ -0,0 +1,410 @@ +import pathlib +from typing import Dict, List, Optional, Tuple, Union + +from transformers.models.detr.image_processing_detr_fast import ( + DetrFastImageProcessorKwargs, + DetrImageProcessorFast, +) + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + SizeDict, + add_start_docstrings, + get_max_height_width, +) +from ...image_transforms import center_to_corners_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + validate_annotations, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, + requires_backends, +) + + +if is_torch_available(): + import torch + + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +logger = logging.get_logger(__name__) + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) + + +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by RT-DETR. + """ + image_height, image_width = image.size()[-2:] + + image_id = target["image_id"] + image_id = torch.as_tensor([image_id], dtype=torch.int64, device=image.device) + + # Get all COCO annotations for the given image. 
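+ # COCO boxes come as [top_left_x, top_left_y, width, height]; below they are converted to
+ # [x0, y0, x1, y1], clipped to the image, and crowd regions (iscrowd == 1) are skipped.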
+ annotations = target["annotations"] + classes = [] + area = [] + boxes = [] + keypoints = [] + for obj in annotations: + if "iscrowd" not in obj or obj["iscrowd"] == 0: + classes.append(obj["category_id"]) + area.append(obj["area"]) + boxes.append(obj["bbox"]) + if "keypoints" in obj: + keypoints.append(obj["keypoints"]) + + classes = torch.as_tensor(classes, dtype=torch.int64, device=image.device) + area = torch.as_tensor(area, dtype=torch.float32, device=image.device) + iscrowd = torch.zeros_like(classes, dtype=torch.int64, device=image.device) + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = { + "image_id": image_id, + "class_labels": classes[keep], + "boxes": boxes[keep], + "area": area[keep], + "iscrowd": iscrowd[keep], + "orig_size": torch.as_tensor([int(image_height), int(image_width)], dtype=torch.int64, device=image.device), + } + + if keypoints: + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=image.device) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + return new_target + + +class RTDetrFastImageProcessorKwargs(DetrFastImageProcessorKwargs): + pass + + +class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + format = AnnotationFormat.COCO_DETECTION + do_convert_annotations = True + do_resize = True + do_rescale = True + do_normalize = False + do_pad = False + size = {"height": 640, "width": 640} + default_to_square = False + model_input_names = ["pixel_values", "pixel_mask"] + valid_kwargs = RTDetrFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorKwargs]) -> None: + # Backwards compatibility + do_convert_annotations = kwargs.get("do_convert_annotations", None) + do_normalize = kwargs.get("do_normalize", None) + if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None: + self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize + + BaseImageProcessorFast.__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. 
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
+ Data format of the annotations. Only "coco_detection" is currently supported.
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
+ Controls whether to convert the annotations to the format expected by the RT-DETR model. Converts the
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
+ Otherwise, the image will be padded to the maximum height and width of the batch.
+ pad_size (`Dict[str, int]`, *optional*):
+ The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
+ Whether to return segmentation masks.
+ masks_path (`str` or `pathlib.Path`, *optional*):
+ Path to the directory containing the segmentation masks.
+ """,
+ )
+ def preprocess(
+ self,
+ images: ImageInput,
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ **kwargs: Unpack[RTDetrFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ return BaseImageProcessorFast.preprocess(self, images, annotations=annotations, masks_path=masks_path, **kwargs)
+
+ def prepare_annotation(
+ self,
+ image: torch.Tensor,
+ target: Dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> Dict:
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ def _preprocess(
+ self,
+ images: List["torch.Tensor"],
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
+ return_segmentation_masks: bool,
+ masks_path: Optional[Union[str, pathlib.Path]],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ do_convert_annotations: bool,
+ image_mean: Optional[Union[float, List[float]]],
+ image_std: Optional[Union[float, List[float]]],
+ do_pad: bool,
+ pad_size: Optional[Dict[str, int]],
+ format: Optional[Union[str, AnnotationFormat]],
+ return_tensors: Optional[Union[str, TensorType]],
+ ) -> BatchFeature:
+ """
+ Preprocess an image or a batch of images so that it can be used by the model.
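+
+ Resizing (together with annotation resizing) is applied first, rescaling and normalization are fused
+ into a single operation, and padding, when enabled, runs in a second pass once all resized shapes are
+ known.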
+ """ + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + data = {} + processed_images = [] + processed_annotations = [] + pixel_masks = [] # Initialize pixel_masks here + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + annotation = self.prepare_annotation( + image, + annotation, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=ChannelDimension.FIRST, + ) + + if do_resize: + resized_image = self.resize(image, size=size, interpolation=interpolation) + if annotations is not None: + annotation = self.resize_annotation( + annotation, + orig_size=image.size()[-2:], + target_size=resized_image.size()[-2:], + ) + image = resized_image + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + if do_convert_annotations and annotations is not None: + annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) + + processed_images.append(image) + processed_annotations.append(annotation) + images = processed_images + annotations = processed_annotations if annotations is not None else None + + if do_pad: + # depends on all resized image shapes so we need another loop + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images) + + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotations if annotations is not None else [None] * len(images)): + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + if padded_size == image.size()[-2:]: + padded_images.append(image) + pixel_masks.append(torch.ones(padded_size, dtype=torch.int64, device=image.device)) + padded_annotations.append(annotation) + continue + image, pixel_mask, annotation = self.pad( + image, padded_size, annotation=annotation, update_bboxes=do_convert_annotations + ) + padded_images.append(image) + padded_annotations.append(annotation) + pixel_masks.append(pixel_mask) + images = padded_images + annotations = padded_annotations if annotations is not None else None + data.update({"pixel_mask": torch.stack(pixel_masks, dim=0)}) + + data.update({"pixel_values": torch.stack(images, dim=0)}) + encoded_inputs = BatchFeature(data, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + return encoded_inputs + + def post_process_object_detection( + self, + outputs, + threshold: float = 0.5, + target_sizes: Union[TensorType, List[Tuple]] = None, + use_focal_loss: bool = True, + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. 
+            threshold (`float`, *optional*, defaults to 0.5):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+            use_focal_loss (`bool`, *optional*, defaults to `True`):
+                Whether focal loss was used to predict the outputs. If `True`, a sigmoid is applied to compute the
+                scores of each detection; otherwise, a softmax function is used.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        requires_backends(self, ["torch"])
+        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+        # convert from relative cxcywh to absolute xyxy
+        boxes = center_to_corners_format(out_bbox)
+        if target_sizes is not None:
+            if len(out_logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+            if isinstance(target_sizes, List):
+                img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+            boxes = boxes * scale_fct[:, None, :]
+
+        num_top_queries = out_logits.shape[1]
+        num_classes = out_logits.shape[2]
+
+        if use_focal_loss:
+            scores = torch.nn.functional.sigmoid(out_logits)
+            scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1)
+            labels = index % num_classes
+            index = index // num_classes
+            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+        else:
+            # softmax over the class dimension; the last ("no-object") class is dropped
+            scores = torch.nn.functional.softmax(out_logits, dim=-1)[:, :, :-1]
+            scores, labels = scores.max(dim=-1)
+            if scores.shape[1] > num_top_queries:
+                scores, index = torch.topk(scores, num_top_queries, dim=-1)
+                labels = torch.gather(labels, dim=1, index=index)
+                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
+        results = []
+        for score, label, box in zip(scores, labels, boxes):
+            results.append(
+                {
+                    "scores": score[score > threshold],
+                    "labels": label[score > threshold],
+                    "boxes": box[score > threshold],
+                }
+            )
+
+        return results
+
+    def from_dict(self, *args, **kwargs):
+        raise NotImplementedError("No need to override this method for RT-DETR yet.")
+
+    def post_process(self, *args, **kwargs):
+        raise NotImplementedError("Post-processing is not implemented for RT-DETR yet.")
+
+    def post_process_segmentation(self, *args, **kwargs):
+        raise NotImplementedError("Segmentation post-processing is not implemented for RT-DETR yet.")
+
+    def post_process_instance(self, *args, **kwargs):
+        raise NotImplementedError("Instance post-processing is not implemented for RT-DETR yet.")
+
+    def post_process_panoptic(self, *args, **kwargs):
+        raise NotImplementedError("Panoptic post-processing is not implemented for RT-DETR yet.")
+
+    def post_process_instance_segmentation(self, *args, **kwargs):
+        raise NotImplementedError("Instance segmentation post-processing is not implemented for RT-DETR yet.")
+
+    def post_process_semantic_segmentation(self, *args, **kwargs):
+        raise NotImplementedError("Semantic segmentation post-processing is not implemented for RT-DETR yet.")
+
+    def post_process_panoptic_segmentation(self, *args, **kwargs):
+        raise NotImplementedError("Panoptic segmentation post-processing is not implemented for RT-DETR yet.")
+
+
+__all__ = ["RTDetrImageProcessorFast"]
diff --git a/docs/transformers/src/transformers/models/rt_detr_v2/__init__.py
b/docs/transformers/src/transformers/models/rt_detr_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efc3772ca6d015271102638c3948958777498fd9 --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr_v2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_rt_detr_v2 import * + from .modeling_rt_detr_v2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py b/docs/transformers/src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..1a16e226c893857282d25042da7b7d891e97eed2 --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py @@ -0,0 +1,379 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rt_detr_v2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class RTDetrV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a + RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RT-DETR architecture. + + e.g. 
[PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        initializer_range (`float`, *optional*, defaults to 0.01):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_bias_prior_prob (`float`, *optional*):
+            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+            If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the batch normalization layers.
+        backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        encoder_hidden_dim (`int`, *optional*, defaults to 256):
+            Dimension of the layers in the hybrid encoder.
+        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+            Channel dimensions of the multi-level features fed to the encoder.
+        feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
+            Strides used in each feature map.
+        encoder_layers (`int`, *optional*, defaults to 1):
+            Total number of layers used by the encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The ratio for all dropout layers.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
+            Indices of the projected layers to be used in the encoder.
+        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+            The temperature parameter used to create the positional encodings.
+        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_function (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        eval_size (`Tuple[int, int]`, *optional*):
+            Height and width used to compute the effective height and width of the position embeddings after taking
+            into account the stride.
+        normalize_before (`bool`, *optional*, defaults to `False`):
+            Determines whether to apply layer normalization in the transformer encoder layer before self-attention and
+            feed-forward modules.
+        hidden_expansion (`float`, *optional*, defaults to 1.0):
+            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers, excluding the hybrid encoder.
+        num_queries (`int`, *optional*, defaults to 300):
+            Number of object queries.
+        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+            Channel dimensions of the multi-level features fed to the decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        num_feature_levels (`int`, *optional*, defaults to 3):
+            The number of input feature levels.
+        decoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the decoder.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_denoising (`int`, *optional*, defaults to 100):
+            The total number of denoising tasks or queries to be used for contrastive denoising.
+        label_noise_ratio (`float`, *optional*, defaults to 0.5):
+            The fraction of denoising labels to which random noise should be added.
+        box_noise_scale (`float`, *optional*, defaults to 1.0):
+            Scale or magnitude of noise to be added to the bounding boxes.
+        learn_initial_query (`bool`, *optional*, defaults to `False`):
+            Indicates whether the initial query embeddings for the decoder should be learned during training.
+        anchor_image_size (`Tuple[int, int]`, *optional*):
+            Height and width of the input image used during evaluation to generate the bounding box anchors. If
+            `None`, anchors are generated automatically.
+        with_box_refine (`bool`, *optional*, defaults to `True`):
+            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+            based on the predictions from the previous layer.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the architecture has an encoder-decoder structure.
+        matcher_alpha (`float`, *optional*, defaults to 0.25):
+            Parameter alpha used by the Hungarian Matcher.
+        matcher_gamma (`float`, *optional*, defaults to 2.0):
+            Parameter gamma used by the Hungarian Matcher.
+        matcher_class_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the class loss used by the Hungarian Matcher.
+        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+            The relative weight of the bounding box loss used by the Hungarian Matcher.
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + decoder_n_levels (`int`, *optional*, defaults to 3): + The number of feature levels used by the decoder. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Scaling factor applied to the attention offsets in the decoder. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + + Examples: + + ```python + >>> from transformers import RTDetrV2Config, RTDetrV2Model + + >>> # Initializing a RT-DETR configuration + >>> configuration = RTDetrV2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RTDetrV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_v2" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + freeze_backbone_batch_norms=True, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RTDetrV2Transformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + 
weight_loss_giou=2.0, + eos_coefficient=1e-4, + decoder_n_levels=3, # default value + decoder_offset_scale=0.5, # default value + decoder_method="default", + **kwargs, + ): + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone." + ) + backbone_model_type = "rt_detr_resnet" + config_class = CONFIG_MAPPING[backbone_model_type] + # this will map it to RTDetrResNetConfig + # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1 + # and we would need to create RTDetrV2ResNetModel + backbone_config = config_class( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + 
self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + + if not hasattr(self, "d_model"): + self.d_model = d_model + + if not hasattr(self, "encoder_attention_heads"): + self.encoder_attention_heads = encoder_attention_heads + # add the new attributes with the given values or defaults + self.decoder_n_levels = decoder_n_levels + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`RTDetrV2Config`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) + + +__all__ = ["RTDetrV2Config"] diff --git a/docs/transformers/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py b/docs/transformers/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..51372b74e4261c646be19198018837b9eaf8c742 --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py @@ -0,0 +1,363 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
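+# A minimal usage sketch for the `RTDetrV2Config` defined in the preceding configuration_rt_detr_v2.py
+# (illustrative only: the override values below are arbitrary and the model is randomly initialized):
+#
+#     from transformers import RTDetrV2Config, RTDetrV2ForObjectDetection
+#
+#     config = RTDetrV2Config(decoder_layers=3, decoder_method="discrete")
+#     model = RTDetrV2ForObjectDetection(config)
+#     assert config.decoder_n_levels == 3 and config.decoder_offset_scale == 0.5  # V2 defaults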
+"""Convert RT Detr V2 checkpoints with Timm backbone""" + +import argparse +import json +import re +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms + +from transformers import RTDetrImageProcessor, RTDetrV2Config, RTDetrV2ForObjectDetection +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_rt_detr_v2_config(model_name: str) -> RTDetrV2Config: + config = RTDetrV2Config() + + config.num_labels = 80 + repo_id = "huggingface/label-files" + filename = "coco-detection-mmdet-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + if model_name == "rtdetr_v2_r18vd": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [2, 2, 2, 2] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 3 + elif model_name == "rtdetr_v2_r34vd": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [3, 4, 6, 3] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 4 + # TODO: check this not working + elif model_name == "rtdetr_v2_r50vd_m": + config.hidden_expansion = 0.5 + elif model_name == "rtdetr_v2_r50vd": + pass + elif model_name == "rtdetr_v2_r101vd": + config.backbone_config.depths = [3, 4, 23, 3] + config.encoder_ffn_dim = 2048 + config.encoder_hidden_dim = 384 + config.decoder_in_channels = [384, 384, 384] + + return config + + +# Define a mapping from original keys to converted keys using regex +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"backbone.conv1.conv1_1.conv.weight": r"model.backbone.model.embedder.embedder.0.convolution.weight", + r"backbone.conv1.conv1_1.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.embedder.0.normalization.\1", + r"backbone.conv1.conv1_2.conv.weight": r"model.backbone.model.embedder.embedder.1.convolution.weight", + r"backbone.conv1.conv1_2.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.embedder.1.normalization.\1", + r"backbone.conv1.conv1_3.conv.weight": r"model.backbone.model.embedder.embedder.2.convolution.weight", + r"backbone.conv1.conv1_3.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.embedder.embedder.2.normalization.\1", + r"backbone.res_layers.(\d+).blocks.(\d+).branch2a.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.0.convolution.weight", + r"backbone.res_layers.(\d+).blocks.(\d+).branch2a.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.0.normalization.\3", + r"backbone.res_layers.(\d+).blocks.(\d+).branch2b.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.1.convolution.weight", + r"backbone.res_layers.(\d+).blocks.(\d+).branch2b.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.1.normalization.\3", + r"backbone.res_layers.(\d+).blocks.(\d+).branch2c.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.2.convolution.weight", + 
r"backbone.res_layers.(\d+).blocks.(\d+).branch2c.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.\2.layer.2.normalization.\3", + r"encoder.encoder.(\d+).layers.0.self_attn.out_proj.weight": r"model.encoder.encoder.\1.layers.0.self_attn.out_proj.weight", + r"encoder.encoder.(\d+).layers.0.self_attn.out_proj.bias": r"model.encoder.encoder.\1.layers.0.self_attn.out_proj.bias", + r"encoder.encoder.(\d+).layers.0.linear1.weight": r"model.encoder.encoder.\1.layers.0.fc1.weight", + r"encoder.encoder.(\d+).layers.0.linear1.bias": r"model.encoder.encoder.\1.layers.0.fc1.bias", + r"encoder.encoder.(\d+).layers.0.linear2.weight": r"model.encoder.encoder.\1.layers.0.fc2.weight", + r"encoder.encoder.(\d+).layers.0.linear2.bias": r"model.encoder.encoder.\1.layers.0.fc2.bias", + r"encoder.encoder.(\d+).layers.0.norm1.weight": r"model.encoder.encoder.\1.layers.0.self_attn_layer_norm.weight", + r"encoder.encoder.(\d+).layers.0.norm1.bias": r"model.encoder.encoder.\1.layers.0.self_attn_layer_norm.bias", + r"encoder.encoder.(\d+).layers.0.norm2.weight": r"model.encoder.encoder.\1.layers.0.final_layer_norm.weight", + r"encoder.encoder.(\d+).layers.0.norm2.bias": r"model.encoder.encoder.\1.layers.0.final_layer_norm.bias", + r"encoder.input_proj.(\d+).conv.weight": r"model.encoder_input_proj.\1.0.weight", + r"encoder.input_proj.(\d+).norm.(.*)": r"model.encoder_input_proj.\1.1.\2", + r"encoder.fpn_blocks.(\d+).conv(\d+).conv.weight": r"model.encoder.fpn_blocks.\1.conv\2.conv.weight", + # r"encoder.fpn_blocks.(\d+).conv(\d+).norm.(.*)": r"model.encoder.fpn_blocks.\1.conv\2.norm.\3", + r"encoder.fpn_blocks.(\d+).conv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv\2.norm.\3", + r"encoder.lateral_convs.(\d+).conv.weight": r"model.encoder.lateral_convs.\1.conv.weight", + r"encoder.lateral_convs.(\d+).norm.(.*)": r"model.encoder.lateral_convs.\1.norm.\2", + r"encoder.fpn_blocks.(\d+).bottlenecks.(\d+).conv(\d+).conv.weight": r"model.encoder.fpn_blocks.\1.bottlenecks.\2.conv\3.conv.weight", + r"encoder.fpn_blocks.(\d+).bottlenecks.(\d+).conv(\d+).norm.(\w+)": r"model.encoder.fpn_blocks.\1.bottlenecks.\2.conv\3.norm.\4", + r"encoder.pan_blocks.(\d+).conv(\d+).conv.weight": r"model.encoder.pan_blocks.\1.conv\2.conv.weight", + r"encoder.pan_blocks.(\d+).conv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv\2.norm.\3", + r"encoder.pan_blocks.(\d+).bottlenecks.(\d+).conv(\d+).conv.weight": r"model.encoder.pan_blocks.\1.bottlenecks.\2.conv\3.conv.weight", + r"encoder.pan_blocks.(\d+).bottlenecks.(\d+).conv(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.bottlenecks.\2.conv\3.norm.\4", + r"encoder.downsample_convs.(\d+).conv.weight": r"model.encoder.downsample_convs.\1.conv.weight", + r"encoder.downsample_convs.(\d+).norm.(weight|bias|running_mean|running_var)": r"model.encoder.downsample_convs.\1.norm.\2", + r"decoder.decoder.layers.(\d+).self_attn.out_proj.weight": r"model.decoder.layers.\1.self_attn.out_proj.weight", + r"decoder.decoder.layers.(\d+).self_attn.out_proj.bias": r"model.decoder.layers.\1.self_attn.out_proj.bias", + r"decoder.decoder.layers.(\d+).cross_attn.sampling_offsets.weight": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.weight", + r"decoder.decoder.layers.(\d+).cross_attn.sampling_offsets.bias": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.bias", + r"decoder.decoder.layers.(\d+).cross_attn.attention_weights.weight": 
r"model.decoder.layers.\1.encoder_attn.attention_weights.weight", + r"decoder.decoder.layers.(\d+).cross_attn.attention_weights.bias": r"model.decoder.layers.\1.encoder_attn.attention_weights.bias", + r"decoder.decoder.layers.(\d+).cross_attn.value_proj.weight": r"model.decoder.layers.\1.encoder_attn.value_proj.weight", + r"decoder.decoder.layers.(\d+).cross_attn.value_proj.bias": r"model.decoder.layers.\1.encoder_attn.value_proj.bias", + r"decoder.decoder.layers.(\d+).cross_attn.output_proj.weight": r"model.decoder.layers.\1.encoder_attn.output_proj.weight", + r"decoder.decoder.layers.(\d+).cross_attn.output_proj.bias": r"model.decoder.layers.\1.encoder_attn.output_proj.bias", + r"decoder.decoder.layers.(\d+).norm1.weight": r"model.decoder.layers.\1.self_attn_layer_norm.weight", + r"decoder.decoder.layers.(\d+).norm1.bias": r"model.decoder.layers.\1.self_attn_layer_norm.bias", + r"decoder.decoder.layers.(\d+).norm2.weight": r"model.decoder.layers.\1.encoder_attn_layer_norm.weight", + r"decoder.decoder.layers.(\d+).norm2.bias": r"model.decoder.layers.\1.encoder_attn_layer_norm.bias", + r"decoder.decoder.layers.(\d+).linear1.weight": r"model.decoder.layers.\1.fc1.weight", + r"decoder.decoder.layers.(\d+).linear1.bias": r"model.decoder.layers.\1.fc1.bias", + r"decoder.decoder.layers.(\d+).linear2.weight": r"model.decoder.layers.\1.fc2.weight", + r"decoder.decoder.layers.(\d+).linear2.bias": r"model.decoder.layers.\1.fc2.bias", + r"decoder.decoder.layers.(\d+).norm3.weight": r"model.decoder.layers.\1.final_layer_norm.weight", + r"decoder.decoder.layers.(\d+).norm3.bias": r"model.decoder.layers.\1.final_layer_norm.bias", + r"decoder.decoder.layers.(\d+).cross_attn.num_points_scale": r"model.decoder.layers.\1.encoder_attn.n_points_scale", + r"decoder.dec_score_head.(\d+).weight": r"model.decoder.class_embed.\1.weight", + r"decoder.dec_score_head.(\d+).bias": r"model.decoder.class_embed.\1.bias", + r"decoder.dec_bbox_head.(\d+).layers.(\d+).(weight|bias)": r"model.decoder.bbox_embed.\1.layers.\2.\3", + r"decoder.denoising_class_embed.weight": r"model.denoising_class_embed.weight", + r"decoder.query_pos_head.layers.0.weight": r"model.decoder.query_pos_head.layers.0.weight", + r"decoder.query_pos_head.layers.0.bias": r"model.decoder.query_pos_head.layers.0.bias", + r"decoder.query_pos_head.layers.1.weight": r"model.decoder.query_pos_head.layers.1.weight", + r"decoder.query_pos_head.layers.1.bias": r"model.decoder.query_pos_head.layers.1.bias", + r"decoder.enc_output.proj.weight": r"model.enc_output.0.weight", + r"decoder.enc_output.proj.bias": r"model.enc_output.0.bias", + r"decoder.enc_output.norm.weight": r"model.enc_output.1.weight", + r"decoder.enc_output.norm.bias": r"model.enc_output.1.bias", + r"decoder.enc_score_head.weight": r"model.enc_score_head.weight", + r"decoder.enc_score_head.bias": r"model.enc_score_head.bias", + r"decoder.enc_bbox_head.layers.(\d+).(weight|bias)": r"model.enc_bbox_head.layers.\1.\2", + r"backbone.res_layers.0.blocks.0.short.conv.weight": r"model.backbone.model.encoder.stages.0.layers.0.shortcut.convolution.weight", + r"backbone.res_layers.0.blocks.0.short.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.0.layers.0.shortcut.normalization.\1", + r"backbone.res_layers.(\d+).blocks.0.short.conv.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.convolution.weight", + r"backbone.res_layers.(\d+).blocks.0.short.conv.norm.(\w+)": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.normalization.\2", + # 
Mapping for subsequent blocks in other stages + r"backbone.res_layers.(\d+).blocks.0.short.conv.weight": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.convolution.weight", + r"backbone.res_layers.(\d+).blocks.0.short.norm.(weight|bias|running_mean|running_var)": r"model.backbone.model.encoder.stages.\1.layers.0.shortcut.1.normalization.\2", + r"decoder.input_proj.(\d+).conv.weight": r"model.decoder_input_proj.\1.0.weight", + r"decoder.input_proj.(\d+).norm.(.*)": r"model.decoder_input_proj.\1.1.\2", +} + + +def convert_old_keys_to_new_keys(state_dict_keys: dict = None): + # Use the mapping to rename keys + for original_key, converted_key in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + for key in list(state_dict_keys.keys()): + new_key = re.sub(original_key, converted_key, key) + if new_key != key: + state_dict_keys[new_key] = state_dict_keys.pop(key) + + return state_dict_keys + + +def read_in_q_k_v(state_dict, config): + prefix = "" + encoder_hidden_dim = config.encoder_hidden_dim + + # first: transformer encoder + for i in range(config.encoder_layers): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[ + :encoder_hidden_dim, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[ + encoder_hidden_dim : 2 * encoder_hidden_dim, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[ + -encoder_hidden_dim:, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(config.decoder_layers): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + + return im + + +@torch.no_grad() +def write_model_and_image_processor(model_name, output_dir, push_to_hub, 
repo_id): + """ + Copy/paste/tweak model's weights to our RTDETR structure. + """ + + # load default config + config = get_rt_detr_v2_config(model_name) + + # load original model from torch hub + model_name_to_checkpoint_url = { + "rtdetr_v2_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth", + "rtdetr_v2_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r34vd_120e_coco_ema.pth", + "rtdetr_v2_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_6x_coco_ema.pth", + "rtdetr_v2_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r101vd_6x_coco_from_paddle.pth", + } + logger.info(f"Converting model {model_name}...") + state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[ + "ema" + ]["module"] + # rename keys + state_dict = convert_old_keys_to_new_keys(state_dict) + for key in state_dict.copy().keys(): + if key.endswith("num_batches_tracked"): + del state_dict[key] + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict, config) + # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them + for key in state_dict.copy().keys(): + if key.endswith("num_batches_tracked"): + del state_dict[key] + # for two_stage + if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): + state_dict[key.split("model.decoder.")[-1]] = state_dict[key] + + # no need in ckpt + del state_dict["decoder.anchors"] + del state_dict["decoder.valid_mask"] + # finally, create HuggingFace model and load state dict + model = RTDetrV2ForObjectDetection(config) + model.load_state_dict(state_dict) + model.eval() + + # load image processor + image_processor = RTDetrImageProcessor() + + # prepare image + img = prepare_img() + + # preprocess image + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension + + encoding = image_processor(images=img, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + assert torch.allclose(original_pixel_values, pixel_values) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + pixel_values = pixel_values.to(device) + + # Pass image by the model + with torch.no_grad(): + outputs = model(pixel_values) + + if model_name == "rtdetr_v2_r18vd": + expected_slice_logits = torch.tensor( + [[-3.7045, -5.1913, -6.1787], [-4.0106, -9.3450, -5.2043], [-4.1287, -4.7463, -5.8634]] + ) + expected_slice_boxes = torch.tensor( + [[0.2582, 0.5497, 0.4764], [0.1684, 0.1985, 0.2120], [0.7665, 0.4146, 0.4669]] + ) + elif model_name == "rtdetr_v2_r34vd": + expected_slice_logits = torch.tensor( + [[-4.6108, -5.9453, -3.8505], [-3.8702, -6.1136, -5.5677], [-3.7790, -6.4538, -5.9449]] + ) + expected_slice_boxes = torch.tensor( + [[0.1691, 0.1984, 0.2118], [0.2594, 0.5506, 0.4736], [0.7669, 0.4136, 0.4654]] + ) + elif model_name == "rtdetr_v2_r50vd": + expected_slice_logits = torch.tensor( + [[-4.7881, -4.6754, -6.1624], [-5.4441, -6.6486, -4.3840], [-3.5455, -4.9318, -6.3544]] + ) + expected_slice_boxes = torch.tensor( + [[0.2588, 0.5487, 0.4747], [0.5497, 0.2760, 0.0573], [0.7688, 0.4133, 0.4634]] + ) + elif model_name == "rtdetr_v2_r101vd": + expected_slice_logits = torch.tensor( 
+ [[-4.6162, -4.9189, -4.6656], [-4.4701, -4.4997, -4.9659], [-5.6641, -7.9000, -5.0725]] + ) + expected_slice_boxes = torch.tensor( + [[0.7707, 0.4124, 0.4585], [0.2589, 0.5492, 0.4735], [0.1688, 0.1993, 0.2108]] + ) + else: + raise ValueError(f"Unknown rt_detr_v2_name: {model_name}") + assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4) + assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3) + + if output_dir is not None: + Path(output_dir).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {output_dir}") + model.save_pretrained(output_dir) + print(f"Saving image processor to {output_dir}") + image_processor.save_pretrained(output_dir) + + if push_to_hub: + # Upload model, image processor and config to the hub + logger.info("Uploading PyTorch model and image processor to the hub...") + config.push_to_hub( + repo_id=repo_id, + commit_message="Add config from convert_rt_detr_v2_original_pytorch_checkpoint_to_pytorch.py", + ) + model.push_to_hub( + repo_id=repo_id, + commit_message="Add model from convert_rt_detr_v2_original_pytorch_checkpoint_to_pytorch.py", + ) + image_processor.push_to_hub( + repo_id=repo_id, + commit_message="Add image processor from convert_rt_detr_v2_original_pytorch_checkpoint_to_pytorch.py", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + default="rtdetr_v2_r18vd", + type=str, + help="model_name of the checkpoint you'd like to convert.", + ) + parser.add_argument("--output_dir", default=None, type=str, help="Location to write HF model and image processor") + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.") + parser.add_argument( + "--repo_id", + type=str, + help="repo_id where the model will be pushed to.", + ) + args = parser.parse_args() + write_model_and_image_processor(args.model_name, args.output_dir, args.push_to_hub, args.repo_id) diff --git a/docs/transformers/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/docs/transformers/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9505c3fdd35b8fd6035b6d1db6364441b522fa2c --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -0,0 +1,2076 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rt_detr_v2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
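+# An illustrative invocation of the conversion script in the preceding file (the output directory and
+# repo id are placeholders):
+#
+#     python convert_rt_detr_v2_weights_to_hf.py \
+#         --model_name rtdetr_v2_r50vd \
+#         --output_dir ./rtdetr_v2_r50vd \
+#         --push_to_hub --repo_id <your-namespace>/rtdetr_v2_r50vd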
+import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2CLS, ACT2FN +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + replace_return_docstrings, + torch_int, +) +from ...utils.backbone_utils import load_backbone +from .configuration_rt_detr_v2 import RTDetrV2Config + + +_CONFIG_FOR_DOC = "RTDetrV2Config" + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: List[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> 
(batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# the main change +class RTDetrV2MultiscaleDeformableAttention(nn.Module): + """ + RTDetrV2 version of multiscale deformable attention, extending the base implementation + with improved offset handling and initialization. + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + num_heads = config.decoder_attention_heads + n_points = config.decoder_n_points + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + + # V2-specific attributes + self.n_levels = config.decoder_n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.offset_scale = config.decoder_offset_scale + self.method = config.decoder_method + + # Initialize n_points list and scale + n_points_list = [self.n_points for _ in range(self.n_levels)] + self.n_points_list = n_points_list + n_points_scale = [1 / n for n in n_points_list for _ in range(n)] + self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # Process inputs up to sampling locations calculation using parent class logic + if position_embeddings is not None: + hidden_states = hidden_states + position_embeddings + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + + # V2-specific sampling offsets shape + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, 
self.n_heads, self.n_levels * self.n_points, 2 + ) + + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1) + + # V2-specific sampling locations calculation + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + # V2-specific attention implementation choice + output = multi_scale_deformable_attention_v2( + value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method + ) + + output = self.output_proj(output) + return output, attention_weights + + +class RTDetrV2MultiheadAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
+            )
+        self.scaling = self.head_dim**-0.5
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+        return tensor if position_embeddings is None else tensor + position_embeddings
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        batch_size, target_len, embed_dim = hidden_states.size()
+        # keep a handle on the raw hidden states: values are projected from them even when position
+        # embeddings are added to the queries and keys (and `position_embeddings` may be `None`)
+        hidden_states_original = hidden_states
+        # add position embeddings to the hidden states before projecting to queries and keys
+        if position_embeddings is not None:
+            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+        # get queries, keys and values
+        query_states = self.q_proj(hidden_states) * self.scaling
+        key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
+        value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
+
+        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+        query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        source_len = key_states.size(1)
+
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
+            raise ValueError(
+                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # expand attention_mask
+        if attention_mask is not None:
+            # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
+            attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
+
+        if attention_mask is not None:
+            if attention_mask.size() != (batch_size, 1, target_len, source_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
+                    f" {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
+            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class RTDetrV2DecoderLayer(nn.Module): + def __init__(self, config: RTDetrV2Config): + super().__init__() + # self-attention + self.self_attn = RTDetrV2MultiheadAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # override only the encoder attention module with v2 version + self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config) + self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # feedforward neural networks + self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model) + self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
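+
+        Each sub-block below (self-attention, multi-scale deformable cross-attention, and the feed-forward
+        network) is followed by dropout, a residual connection, and layer normalization, in that order.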
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@dataclass +class RTDetrV2DecoderOutput(ModelOutput): + """ + Base class for outputs of the RTDetrV2Decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + intermediate_hidden_states: Optional[torch.FloatTensor] = None + intermediate_logits: Optional[torch.FloatTensor] = None + intermediate_reference_points: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RTDetrV2ModelOutput(ModelOutput): + """ + Base class for outputs of the RT-DETR encoder-decoder model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
+            layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+            Initial reference points sent through the Transformer decoder.
+        enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+            Classification logits of the top `config.num_queries` scoring bounding boxes picked as region proposals in
+            the encoder stage.
+        enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
+            Normalized coordinates (after sigmoid) of the top scoring bounding boxes picked in the encoder stage.
+        enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+            picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+            foreground and background).
+        enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Logits of predicted bounding boxes coordinates in the first stage.
+        denoising_meta_values (`dict`):
+            Extra dictionary holding the denoising-related values.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_logits: Optional[torch.FloatTensor] = None
+    intermediate_reference_points: Optional[torch.FloatTensor] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    init_reference_points: Optional[torch.FloatTensor] = None
+    enc_topk_logits: Optional[torch.FloatTensor] = None
+    enc_topk_bboxes: Optional[torch.FloatTensor] = None
+    enc_outputs_class: Optional[torch.FloatTensor] = None
+    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+    denoising_meta_values: Optional[Dict] = None
+
+
+@dataclass
+class RTDetrV2ObjectDetectionOutput(ModelOutput):
+    """
+    Output type of [`RTDetrV2ForObjectDetection`].
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+            bounding box loss.
The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~RTDetrV2ImageProcessor.post_process_object_detection`] to retrieve the + unnormalized (absolute) bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
+            layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+            Initial reference points sent through the Transformer decoder.
+        enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Classification logits of the top scoring bounding boxes picked in the encoder stage.
+        enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Normalized coordinates (after sigmoid) of the top scoring bounding boxes picked in the encoder stage.
+        enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+            picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+            foreground and background).
+        enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+            Logits of predicted bounding boxes coordinates in the first stage.
+        denoising_meta_values (`dict`):
+            Extra dictionary holding the denoising-related values.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    loss_dict: Optional[Dict] = None
+    logits: Optional[torch.FloatTensor] = None
+    pred_boxes: Optional[torch.FloatTensor] = None
+    auxiliary_outputs: Optional[List[Dict]] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_logits: Optional[torch.FloatTensor] = None
+    intermediate_reference_points: Optional[torch.FloatTensor] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    init_reference_points: Optional[torch.FloatTensor] = None
+    enc_topk_logits: Optional[torch.FloatTensor] = None
+    enc_topk_bboxes: Optional[torch.FloatTensor] = None
+    enc_outputs_class: Optional[torch.FloatTensor] = None
+    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+    denoising_meta_values: Optional[Dict] = None
+
+
+class RTDetrV2FrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
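+
+    In the forward pass the layer applies `y = (x - running_mean) / sqrt(running_var + eps) * weight + bias`
+    using the stored statistics, i.e. it always behaves like `nn.BatchNorm2d` in eval mode regardless of
+    `self.training`.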
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any models other than
+    torchvision.models.resnet[18,34,50,101] produce nans.
+    """
+
+    def __init__(self, n):
+        super().__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        weight = self.weight.reshape(1, -1, 1, 1)
+        bias = self.bias.reshape(1, -1, 1, 1)
+        running_var = self.running_var.reshape(1, -1, 1, 1)
+        running_mean = self.running_mean.reshape(1, -1, 1, 1)
+        epsilon = 1e-5
+        scale = weight * (running_var + epsilon).rsqrt()
+        bias = bias - running_mean * scale
+        return x * scale + bias
+
+
+def replace_batch_norm(model):
+    r"""
+    Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrV2FrozenBatchNorm2d`.
+
+    Args:
+        model (torch.nn.Module):
+            input model
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.BatchNorm2d):
+            new_module = RTDetrV2FrozenBatchNorm2d(module.num_features)
+
+            if module.weight.device != torch.device("meta"):
+                new_module.weight.data.copy_(module.weight)
+                new_module.bias.data.copy_(module.bias)
+                new_module.running_mean.data.copy_(module.running_mean)
+                new_module.running_var.data.copy_(module.running_var)
+
+            model._modules[name] = new_module
+
+        if len(list(module.children())) > 0:
+            replace_batch_norm(module)
+
+
+class RTDetrV2ConvEncoder(nn.Module):
+    """
+    Convolutional backbone, based on modeling_rt_detr_v2_resnet.py.
+
+    nn.BatchNorm2d layers are replaced by RTDetrV2FrozenBatchNorm2d as defined above.
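+
+    Freezing the batch norm statistics (enabled via `config.freeze_backbone_batch_norms`) keeps a pretrained
+    backbone numerically stable when fine-tuning the detection heads with small batch sizes.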
+ https://github.com/lyuwenyu/RT-DETR/blob/main/RTDetrV2_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class RTDetrV2ConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrV2EncoderLayer(nn.Module): + def __init__(self, config: RTDetrV2Config): + super().__init__() + self.normalize_before = config.normalize_before + + # self-attention + self.self_attn = RTDetrV2MultiheadAttention( + embed_dim=config.encoder_hidden_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.encoder_activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
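+
+        Depending on `config.normalize_before`, layer normalization is applied either before each sub-block
+        (pre-norm) or after its residual connection (post-norm).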
+ """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class RTDetrV2RepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + + activation = config.activation_function + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = RTDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class RTDetrV2CSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. 
+ """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.activation_function + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = RTDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[RTDetrV2RepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = RTDetrV2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1) + hidden_state_2 = self.conv2(hidden_state) + return self.conv3(hidden_state_1 + hidden_state_2) + + +class RTDetrV2Encoder(nn.Module): + def __init__(self, config: RTDetrV2Config): + super().__init__() + + self.layers = nn.ModuleList([RTDetrV2EncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor: + hidden_states = src + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + return hidden_states + + +class RTDetrV2HybridEncoder(nn.Module): + """ + Decoder consisting of a projection layer, a set of `RTDetrV2Encoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://arxiv.org/abs/2304.08069 + + Args: + config: RTDetrV2Config + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + self.config = config + self.in_channels = config.encoder_in_channels + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.num_fpn_stages = len(self.in_channels) - 1 + self.num_pan_stages = len(self.in_channels) - 1 + activation = config.activation_function + + # encoder transformer + self.encoder = nn.ModuleList([RTDetrV2Encoder(config) for _ in range(len(self.encode_proj_layers))]) + + # top-down FPN + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(self.num_fpn_stages): + lateral_conv = RTDetrV2ConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=1, + stride=1, + activation=activation, + ) + fpn_block = RTDetrV2CSPRepLayer(config) + self.lateral_convs.append(lateral_conv) + self.fpn_blocks.append(fpn_block) + + # bottom-up PAN + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(self.num_pan_stages): + downsample_conv = RTDetrV2ConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=3, + stride=2, + activation=activation, + ) + pan_block = RTDetrV2CSPRepLayer(config) + self.downsample_convs.append(downsample_conv) + self.pan_blocks.append(pan_block) + + @staticmethod + 
def build_2d_sincos_position_embedding( + width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ): + grid_w = torch.arange(torch_int(width), device=device).to(dtype) + grid_h = torch.arange(torch_int(height), device=device).to(dtype) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
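+
+        A quick shape check of the sin-cos position-embedding helper used below (a minimal sketch; the shapes
+        are illustrative only):
+
+        ```python
+        >>> pos = RTDetrV2HybridEncoder.build_2d_sincos_position_embedding(width=4, height=4, embed_dim=256)
+        >>> list(pos.shape)  # one `embed_dim`-sized embedding per spatial location
+        [1, 16, 256]
+        ```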
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # encoder + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + height, width = hidden_states[enc_ind].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + device=src_flatten.device, + dtype=src_flatten.dtype, + ) + else: + pos_embed = None + + layer_outputs = self.encoder[i]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + hidden_states[enc_ind] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + + # top-down FPN + fpn_feature_maps = [hidden_states[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + # apply lateral block + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + # apply fpn block + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps = fpn_feature_maps[::-1] + + # bottom-up PAN + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + if not return_dict: + return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions + ) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. 
+ + Args: + targets (`List[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. + """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. 
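+    # tile the padded ground truths 2 * num_groups times: within every group, the first max_gt_num slots
+    # hold the positive (lightly noised) queries and the second max_gt_num slots the negative
+    # (heavily shifted) ones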
+    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+    # positive and negative mask
+    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+    negative_gt_mask[:, max_gt_num:] = 1
+    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+    positive_gt_mask = 1 - negative_gt_mask
+    # contrastive denoising training positive index
+    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+    denoise_positive_idx = torch.split(
+        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+    )
+    # total denoising queries
+    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+    if label_noise_ratio > 0:
+        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+        # randomly replace the selected labels with a class drawn uniformly at random
+        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+    if box_noise_scale > 0:
+        known_bbox = center_to_corners_format(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(input_query_bbox)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = corners_to_center_format(known_bbox)
+        input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+    input_query_class = class_embed(input_query_class)
+
+    target_size = num_denoising_queries + num_queries
+    attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device)
+    # match query cannot see the reconstruction
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = True
+
+    # reconstructions cannot see each other
+    for i in range(num_groups_denoising_queries):
+        idx_block_start = max_gt_num * 2 * i
+        idx_block_end = max_gt_num * 2 * (i + 1)
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True
+
+    denoising_meta_values = {
+        "dn_positive_idx": denoise_positive_idx,
+        "dn_num_group": num_groups_denoising_queries,
+        "dn_num_split": [num_denoising_queries, num_queries],
+    }
+
+    return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
+
+
+RTDetrV2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`RTDetrV2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration.
Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RTDetrV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`RTDetrV2ImageProcessor.__call__`] for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +def _get_clones(partial_module, N): + return nn.ModuleList([partial_module() for i in range(N)]) + + +class RTDetrV2PreTrainedModel(PreTrainedModel): + config_class = RTDetrV2Config + base_model_prefix = "rt_detr_v2" + main_input_name = "pixel_values" + _no_split_modules = [r"RTDetrV2HybridEncoder", r"RTDetrV2DecoderLayer"] + + def _init_weights(self, module): + """Initalize the weights""" + if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(layer.weight) + nn.init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + nn.init.constant_(layer.layers[-1].weight, 0) + nn.init.constant_(layer.layers[-1].bias, 0) + + elif isinstance(module, RTDetrV2MultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + + elif isinstance(module, RTDetrV2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(module.enc_score_head.weight) + nn.init.constant_(module.enc_score_head.bias, bias) + + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + nn.init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + nn.init.xavier_uniform_(module.denoising_class_embed.weight) + + +class RTDetrV2Decoder(RTDetrV2PreTrainedModel): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.query_pos_head = RTDetrV2MLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + 
+        position_embeddings=None,
+        reference_points=None,
+        spatial_shapes=None,
+        spatial_shapes_list=None,
+        level_start_index=None,
+        valid_ratios=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
+                The query embeddings that are passed into the decoder.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
+                in `[0, 1]`:
+                - 1 for pixels that are real (i.e. **not masked**),
+                - 0 for pixels that are padding (i.e. **masked**).
+            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Position embeddings that are added to the queries and keys in each self-attention layer.
+            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+                Reference point in range `[0, 1]`, top-left (0, 0), bottom-right (1, 1), including padding area.
+            spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
+                Spatial shapes of the feature maps.
+            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+                Indexes for the start of each feature level. In range `[0, sequence_length]`.
+            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
+                Ratio of valid area in each feature level.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
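+
+        At every decoder layer, new position embeddings are computed from the current reference points via
+        `query_pos_head`, and when `bbox_embed` is set the reference points are refined with the predicted box
+        deltas and detached before the next layer (iterative bounding box refinement).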
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + + reference_points = F.sigmoid(reference_points) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L252 + for idx, decoder_layer in enumerate(self.layers): + reference_points_input = reference_points.unsqueeze(2) + position_embeddings = self.query_pos_head(reference_points) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[idx](hidden_states) + new_reference_points = F.sigmoid(tmp + inverse_sigmoid(reference_points)) + reference_points = new_reference_points.detach() + + intermediate += (hidden_states,) + intermediate_reference_points += ( + (new_reference_points,) if self.bbox_embed is not None else (reference_points,) + ) + + if self.class_embed is not None: + logits = self.class_embed[idx](hidden_states) + intermediate_logits += (logits,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + if self.class_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return RTDetrV2DecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. 
+ """, + RTDetrV2_START_DOCSTRING, +) +class RTDetrV2Model(RTDetrV2PreTrainedModel): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + + # Create backbone + self.backbone = RTDetrV2ConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + + # Create encoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/hybrid_encoder.py#L212 + num_backbone_outs = len(intermediate_channel_sizes) + encoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[_] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + + # Create encoder + self.encoder = RTDetrV2HybridEncoder(config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = RTDetrV2MLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + + # Create decoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412 + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = config.decoder_in_channels[_] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list) + # decoder + self.decoder = RTDetrV2Decoder(config) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_method_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, 
device=device).to(dtype), + torch.arange(end=width, device=device).to(dtype), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @add_start_docstrings_to_model_forward(RTDetrV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RTDetrV2ModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], RTDetrV2ModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, RTDetrV2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd") + >>> model = RTDetrV2Model.from_pretrained("PekingU/RTDetrV2_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + features = self.backbone(pixel_values, pixel_mask) + + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if output_hidden_states else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 + else encoder_outputs[1] + if output_attentions + else None, + ) + + # Equivalent to def _get_encoder_input + # 
+        # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_pytorch/src/zoo/RTDetrV2/RTDetrV2_decoder.py#L412
+        sources = []
+        for level, source in enumerate(encoder_outputs[0]):
+            sources.append(self.decoder_input_proj[level](source))
+
+        # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
+        if self.config.num_feature_levels > len(sources):
+            _len_sources = len(sources)
+            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0][-1]))
+            for i in range(_len_sources + 1, self.config.num_feature_levels):
+                sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+
+        # Prepare encoder inputs (by flattening)
+        source_flatten = []
+        spatial_shapes_list = []
+        spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
+        for level, source in enumerate(sources):
+            height, width = source.shape[-2:]
+            spatial_shapes[level, 0] = height
+            spatial_shapes[level, 1] = width
+            spatial_shapes_list.append((height, width))
+            source = source.flatten(2).transpose(1, 2)
+            source_flatten.append(source)
+        source_flatten = torch.cat(source_flatten, 1)
+        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+
+        # prepare denoising training
+        if self.training and self.config.num_denoising > 0 and labels is not None:
+            (
+                denoising_class,
+                denoising_bbox_unact,
+                attention_mask,
+                denoising_meta_values,
+            ) = get_contrastive_denoising_training_group(
+                targets=labels,
+                num_classes=self.config.num_labels,
+                num_queries=self.config.num_queries,
+                class_embed=self.denoising_class_embed,
+                num_denoising_queries=self.config.num_denoising,
+                label_noise_ratio=self.config.label_noise_ratio,
+                box_noise_scale=self.config.box_noise_scale,
+            )
+        else:
+            denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
+
+        batch_size = len(source_flatten)
+        device = source_flatten.device
+        dtype = source_flatten.dtype
+
+        # prepare input for decoder
+        if self.training or self.config.anchor_image_size is None:
+            # Pass spatial_shapes as tuple to make it hashable and make sure
+            # lru_cache is working for generate_anchors()
+            spatial_shapes_tuple = tuple(spatial_shapes_list)
+            anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
+        else:
+            anchors, valid_mask = self.anchors, self.valid_mask
+            anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
+
+        # use the valid_mask to selectively retain values in the feature map where the mask is `True`
+        memory = valid_mask.to(source_flatten.dtype) * source_flatten
+
+        output_memory = self.enc_output(memory)
+
+        enc_outputs_class = self.enc_score_head(output_memory)
+        enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
+
+        _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
+
+        reference_points_unact = enc_outputs_coord_logits.gather(
+            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
+        )
+
+        enc_topk_bboxes = F.sigmoid(reference_points_unact)
+        if denoising_bbox_unact is not None:
+            reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
+
+        enc_topk_logits = enc_outputs_class.gather(
+            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+        )
+
+        # extract region features
+        if self.config.learn_initial_query:
+            target = self.weight_embedding.tile([batch_size, 1, 1])
+        else:
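+            # otherwise, gather the encoder features of the top-k proposals (detached below)
+            # and use them as the initial decoder queries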
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits] + if value is not None + ) + dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values]) + tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs + + return tuple_outputs + + return RTDetrV2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class RTDetrV2MLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RTDetrV2_paddle/ppdet/modeling/transformers/utils.py#L453 + + """ + + def __init__(self, config, input_dim, d_model, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [d_model] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further + decoded into scores and classes. 
+ """, + RTDetrV2_START_DOCSTRING, +) +class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = ["bbox_embed", "class_embed"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + # RTDETR encoder-decoder model + self.model = RTDetrV2Model(config) + + # Detection heads on top + class_embed = partial(nn.Linear, config.d_model, config.num_labels) + bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)]) + self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)]) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @add_start_docstrings_to_model_forward(RTDetrV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RTDetrV2ObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **loss_kwargs, + ) -> Union[Tuple[torch.FloatTensor], RTDetrV2ObjectDetectionOutput]: + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
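+
+            For example, a single image with two annotated boxes (boxes given in the normalized
+            `(center_x, center_y, width, height)` format used by DETR-style losses):
+
+            ```python
+            labels = [
+                {
+                    "class_labels": torch.tensor([0, 1]),
+                    "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3], [0.3, 0.4, 0.1, 0.1]]),
+                }
+            ]
+            ```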
+ + Returns: + + Examples: + + ```python + >>> from transformers import RTDetrV2ImageProcessor, RTDetrV2ForObjectDetection + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = RTDetrV2ImageProcessor.from_pretrained("PekingU/RTDetrV2_r50vd") + >>> model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/RTDetrV2_r50vd") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21] + Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5] + Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22] + Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48] + Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + denoising_meta_values = ( + outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None + ) + + outputs_class = outputs.intermediate_logits if return_dict else outputs[2] + outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3] + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5] + enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4] + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + **loss_kwargs, + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs + else: + output = (logits, pred_boxes) + outputs + return 
((loss, loss_dict) + output) if loss is not None else output + + return RTDetrV2ObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) + + +__all__ = ["RTDetrV2Model", "RTDetrV2PreTrainedModel", "RTDetrV2ForObjectDetection"] diff --git a/docs/transformers/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py b/docs/transformers/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd2a4613192c4cc062f3c6063aeb01f1623aec2 --- /dev/null +++ b/docs/transformers/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py @@ -0,0 +1,628 @@ +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from functools import partial +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...configuration_utils import PretrainedConfig +from ...utils import is_torchdynamo_compiling, logging +from ...utils.backbone_utils import ( + verify_backbone_config_arguments, +) +from ..auto import CONFIG_MAPPING +from ..rt_detr.modeling_rt_detr import ( + RTDetrDecoder, + RTDetrDecoderLayer, + RTDetrForObjectDetection, + RTDetrMLPPredictionHead, + RTDetrModel, + RTDetrPreTrainedModel, +) + + +logger = logging.get_logger(__name__) + + +class RTDetrV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate a + RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RT-DETR architecture. + + e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+
+    Args:
+        initializer_range (`float`, *optional*, defaults to 0.01):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_bias_prior_prob (`float`, *optional*):
+            The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
+            If `None`, `prior_prob` is computed as `1 / (num_labels + 1)` when initializing the model weights.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the batch normalization layers.
+        backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint,
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        encoder_hidden_dim (`int`, *optional*, defaults to 256):
+            Dimension of the layers in the hybrid encoder.
+        encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
+            Multi-level feature input channels for the encoder.
+        feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
+            Strides used in each feature map.
+        encoder_layers (`int`, *optional*, defaults to 1):
+            Total number of layers used by the encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The ratio used by all dropout layers.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
+            Indices of the projected layers to be used in the encoder.
+        positional_encoding_temperature (`int`, *optional*, defaults to 10000):
+            The temperature parameter used to create the positional encodings.
+        encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        activation_function (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        eval_size (`Tuple[int, int]`, *optional*):
+            Height and width used to compute the effective height and width of the position embeddings after taking
+            into account the stride.
+        normalize_before (`bool`, *optional*, defaults to `False`):
+            Determines whether to apply layer normalization in the transformer encoder layer before the self-attention
+            and feed-forward modules.
+        hidden_expansion (`float`, *optional*, defaults to 1.0):
+            Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers, excluding the hybrid encoder.
+        num_queries (`int`, *optional*, defaults to 300):
+            Number of object queries.
+        decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
+            Multi-level feature dimensions for the decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 1024):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        num_feature_levels (`int`, *optional*, defaults to 3):
+            The number of input feature levels.
+        decoder_n_points (`int`, *optional*, defaults to 4):
+            The number of sampled keys in each feature level for each attention head in the decoder.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_denoising (`int`, *optional*, defaults to 100):
+            The total number of denoising tasks or queries to be used for contrastive denoising.
+        label_noise_ratio (`float`, *optional*, defaults to 0.5):
+            The fraction of denoising labels to which random noise should be added.
+        box_noise_scale (`float`, *optional*, defaults to 1.0):
+            Scale or magnitude of noise to be added to the bounding boxes.
+        learn_initial_query (`bool`, *optional*, defaults to `False`):
+            Indicates whether the initial query embeddings for the decoder should be learned during training.
+        anchor_image_size (`Tuple[int, int]`, *optional*):
+            Height and width of the input image used during evaluation to generate the bounding box anchors. If
+            `None`, anchors are generated automatically.
+        with_box_refine (`bool`, *optional*, defaults to `True`):
+            Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
+            based on the predictions from the previous layer.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the architecture has an encoder-decoder structure.
+        matcher_alpha (`float`, *optional*, defaults to 0.25):
+            Parameter alpha used by the Hungarian Matcher.
+        matcher_gamma (`float`, *optional*, defaults to 2.0):
+            Parameter gamma used by the Hungarian Matcher.
+        matcher_class_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the class loss used by the Hungarian Matcher.
+        matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
+            The relative weight of the bounding box loss used by the Hungarian Matcher.
+        matcher_giou_cost (`float`, *optional*, defaults to 2.0):
+            The relative weight of the giou loss used by the Hungarian Matcher.
+ use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + decoder_n_levels (`int`, *optional*, defaults to 3): + The number of feature levels used by the decoder. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Scaling factor applied to the attention offsets in the decoder. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + + Examples: + + ```python + >>> from transformers import RTDetrV2Config, RTDetrV2Model + + >>> # Initializing a RT-DETR configuration + >>> configuration = RTDetrV2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RTDetrV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_v2" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + freeze_backbone_batch_norms=True, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RTDetrV2Transformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + eos_coefficient=1e-4, + decoder_n_levels=3, # default value + decoder_offset_scale=0.5, # default value + 
decoder_method="default", + **kwargs, + ): + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone." + ) + backbone_model_type = "rt_detr_resnet" + config_class = CONFIG_MAPPING[backbone_model_type] + # this will map it to RTDetrResNetConfig + # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1 + # and we would need to create RTDetrV2ResNetModel + backbone_config = config_class( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + 
self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + + if not hasattr(self, "d_model"): + self.d_model = d_model + + if not hasattr(self, "encoder_attention_heads"): + self.encoder_attention_heads = encoder_attention_heads + # add the new attributes with the given values or defaults + self.decoder_n_levels = decoder_n_levels + self.decoder_offset_scale = decoder_offset_scale + self.decoder_method = decoder_method + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`RTDetrV2Config`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: List[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + 
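+            # note: unlike the bilinear `grid_sample` path above, "discrete" sampling rounds
+            # each location to the nearest pixel and gathers it directly, with no interpolation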
sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# the main change +class RTDetrV2MultiscaleDeformableAttention(nn.Module): + """ + RTDetrV2 version of multiscale deformable attention, extending the base implementation + with improved offset handling and initialization. + """ + + def __init__(self, config: RTDetrV2Config): + super().__init__() + num_heads = config.decoder_attention_heads + n_points = config.decoder_n_points + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RTDetrV2MultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + + # V2-specific attributes + self.n_levels = config.decoder_n_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.offset_scale = config.decoder_offset_scale + self.method = config.decoder_method + + # Initialize n_points list and scale + n_points_list = [self.n_points for _ in range(self.n_levels)] + self.n_points_list = n_points_list + n_points_scale = [1 / n for n in n_points_list for _ in range(n)] + self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # Process inputs up to sampling locations calculation using parent class logic + if position_embeddings is not None: + hidden_states = hidden_states + position_embeddings + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + + # V2-specific 
sampling offsets shape + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2 + ) + + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1) + + # V2-specific sampling locations calculation + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + # V2-specific attention implementation choice + output = multi_scale_deformable_attention_v2( + value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method + ) + + output = self.output_proj(output) + return output, attention_weights + + +class RTDetrV2DecoderLayer(RTDetrDecoderLayer): + def __init__(self, config: RTDetrV2Config): + # initialize parent class + super().__init__(config) + # override only the encoder attention module with v2 version + self.encoder_attn = RTDetrV2MultiscaleDeformableAttention(config) + + +class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel): + pass + + +class RTDetrV2Decoder(RTDetrDecoder): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + self.layers = nn.ModuleList([RTDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) + + +class RTDetrV2Model(RTDetrModel): + def __init__(self, config: RTDetrV2Config): + super().__init__(config) + # decoder + self.decoder = RTDetrV2Decoder(config) + + +class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead): + pass + + +class RTDetrV2ForObjectDetection(RTDetrForObjectDetection, RTDetrV2PreTrainedModel): + def __init__(self, config: RTDetrV2Config): + RTDetrV2PreTrainedModel.__init__(config) + # RTDETR encoder-decoder model + self.model = RTDetrV2Model(config) + + # Detection heads on top + class_embed = partial(nn.Linear, config.d_model, config.num_labels) + bbox_embed = partial(RTDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + self.class_embed = nn.ModuleList([class_embed() for _ in range(config.decoder_layers)]) + self.bbox_embed = nn.ModuleList([bbox_embed() for _ in range(config.decoder_layers)]) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + +__all__ = [ + "RTDetrV2Config", + "RTDetrV2Model", + "RTDetrV2PreTrainedModel", + "RTDetrV2ForObjectDetection", +] diff --git a/docs/transformers/src/transformers/models/rwkv/__init__.py b/docs/transformers/src/transformers/models/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfdb86827a0f5551fc99ad84edf2c2684444e96 --- /dev/null +++ b/docs/transformers/src/transformers/models/rwkv/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_rwkv import *
+    from .modeling_rwkv import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/rwkv/configuration_rwkv.py b/docs/transformers/src/transformers/models/rwkv/configuration_rwkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..90c5cf7e1c8b3bd6e6be973d6b3b7d058daf9848
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rwkv/configuration_rwkv.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RWKV configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class RwkvConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the RWKV-4
+    [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50277):
+            Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`RwkvModel`].
+        context_length (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model can be used with in a single forward (using it in RNN mode
+            lets you use any sequence length).
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the embeddings and hidden states.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the model.
+        attention_hidden_size (`int`, *optional*):
+            Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
+        intermediate_size (`int`, *optional*):
+            Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
+            as GPTNeoX.
+        eos_token_id (`int`, *optional*, defaults to 0):
+            The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
+            GPTNeoX.
+        rescale_every (`int`, *optional*, defaults to 6):
+            At inference, the hidden states (and weights of the corresponding output layers) are divided by 2 every
+            `rescale_every` layer. If set to 0 or a negative number, no rescale is done.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether or not to tie the word embeddings with the input token embeddings.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last state.
+
+
+    Example:
+
+    ```python
+    >>> from transformers import RwkvConfig, RwkvModel
+
+    >>> # Initializing a Rwkv configuration
+    >>> configuration = RwkvConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = RwkvModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "rwkv"
+    attribute_map = {"max_position_embeddings": "context_length"}
+
+    def __init__(
+        self,
+        vocab_size=50277,
+        context_length=1024,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        attention_hidden_size=None,
+        intermediate_size=None,
+        layer_norm_epsilon=1e-5,
+        bos_token_id=0,
+        eos_token_id=0,
+        rescale_every=6,
+        tie_word_embeddings=False,
+        use_cache=True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.context_length = context_length
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.rescale_every = rescale_every
+        self.use_cache = use_cache
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
+        )
+
+
+__all__ = ["RwkvConfig"]
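A small standalone sketch of how the derived defaults above resolve in practice:

```python
from transformers import RwkvConfig

config = RwkvConfig(hidden_size=512, num_hidden_layers=12)
assert config.attention_hidden_size == 512  # falls back to hidden_size
assert config.intermediate_size == 4 * 512  # falls back to 4 * hidden_size
```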
diff --git a/docs/transformers/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py b/docs/transformers/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d35db223632d868ef3727fc6db7d5bb192ae2d
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert a RWKV checkpoint from BlinkDL to the Hugging Face format."""
+
+import argparse
+import gc
+import json
+import os
+import re
+
+import torch
+from huggingface_hub import hf_hub_download, split_torch_state_dict_into_shards
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, RwkvConfig
+from transformers.modeling_utils import WEIGHTS_INDEX_NAME
+
+
+NUM_HIDDEN_LAYERS_MAPPING = {
+    "169M": 12,
+    "430M": 24,
+    "1B5": 24,
+    "3B": 32,
+    "7B": 32,
+    "14B": 40,
+}
+
+HIDEN_SIZE_MAPPING = {
+    "169M": 768,
+    "430M": 1024,
+    "1B5": 2048,
+    "3B": 2560,
+    "7B": 4096,
+    "14B": 5120,
+}
+
+
+def convert_state_dict(state_dict):
+    state_dict_keys = list(state_dict.keys())
+    for name in state_dict_keys:
+        weight = state_dict.pop(name)
+        # emb -> embeddings
+        if name.startswith("emb."):
+            name = name.replace("emb.", "embeddings.")
+        # ln0 -> pre_ln (only present at block 0)
+        if name.startswith("blocks.0.ln0"):
+            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
+        # att -> attention
+        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
+        # ffn -> feed_forward
+        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
+        # time_mix_k -> time_mix_key
+        if name.endswith(".time_mix_k"):
+            name = name.replace(".time_mix_k", ".time_mix_key")
+        # time_mix_v -> time_mix_value
+        if name.endswith(".time_mix_v"):
+            name = name.replace(".time_mix_v", ".time_mix_value")
+        # time_mix_r -> time_mix_receptance
+        if name.endswith(".time_mix_r"):
+            name = name.replace(".time_mix_r", ".time_mix_receptance")
+
+        if name != "head.weight":
+            name = "rwkv." + name
+
+        state_dict[name] = weight
+    return state_dict
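As a quick illustration (toy tensors, hypothetical key names), `convert_state_dict` renames typical BlinkDL keys like so:

```python
import torch

toy = {
    "emb.weight": torch.zeros(1),
    "blocks.0.ln0.weight": torch.zeros(1),
    "blocks.2.att.time_mix_k": torch.zeros(1),
    "blocks.2.ffn.key.weight": torch.zeros(1),
    "head.weight": torch.zeros(1),
}
print(sorted(convert_state_dict(toy)))
# ['head.weight', 'rwkv.blocks.0.pre_ln.weight', 'rwkv.blocks.2.attention.time_mix_key',
#  'rwkv.blocks.2.feed_forward.key.weight', 'rwkv.embeddings.weight']
```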
+
+
+def convert_rmkv_checkpoint_to_hf_format(
+    repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None
+):
+    # 1. If possible, build the tokenizer.
+    if tokenizer_file is None:
+        print("No `--tokenizer_file` provided, we will use the default tokenizer.")
+        vocab_size = 50277
+        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+    else:
+        tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
+        vocab_size = len(tokenizer)
+    tokenizer.save_pretrained(output_dir)
+
+    # 2. Build the config
+    possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys())
+    if size is None:
+        # Try to infer size from the checkpoint name
+        for candidate in possible_sizes:
+            if candidate in checkpoint_file:
+                size = candidate
+                break
+        if size is None:
+            raise ValueError("Could not infer the size, please provide it with the `--size` argument.")
+    if size not in possible_sizes:
+        raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.")
+
+    config = RwkvConfig(
+        vocab_size=vocab_size,
+        num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size],
+        hidden_size=HIDEN_SIZE_MAPPING[size],
+    )
+    config.save_pretrained(output_dir)
+
+    # 3. Download model file then convert state_dict
+    model_file = hf_hub_download(repo_id, checkpoint_file)
+    state_dict = torch.load(model_file, map_location="cpu", weights_only=True)
+    state_dict = convert_state_dict(state_dict)
+
+    # 4. Split in shards and save
+    state_dict_split = split_torch_state_dict_into_shards(state_dict)
+    index = None
+    shards = {
+        filename: {tensor: state_dict[tensor] for tensor in tensors}
+        for filename, tensors in state_dict_split.filename_to_tensors.items()
+    }
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+
+    for shard_file, shard in shards.items():
+        torch.save(shard, os.path.join(output_dir, shard_file))
+
+    if index is not None:
+        save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME)
+        # Save the index as well
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+
+    # 5. Clean up shards (for some reason the files PyTorch saves take the same space as the whole state_dict)
+    print(
+        "Cleaning up shards. This may raise an OOM error; if this is the case, don't worry, you still have converted the model."
+    )
+    shard_files = list(shards.keys())
+
+    del state_dict
+    del shards
+    gc.collect()
+
+    for shard_file in shard_files:
+        state_dict = torch.load(os.path.join(output_dir, shard_file), weights_only=True)
+        torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file))
+
+    del state_dict
+    gc.collect()
+
+    if push_to_hub:
+        if model_name is None:
+            raise ValueError("Please provide a `model_name` to push the model to the Hub.")
+        model = AutoModelForCausalLM.from_pretrained(output_dir)
+        model.push_to_hub(model_name, max_shard_size="2GB")
+        tokenizer.push_to_hub(model_name)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint."
+    )
+    parser.add_argument(
+        "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo."
+    )
+    parser.add_argument(
+        "--output_dir", default=None, type=str, required=True, help="Where to save the converted model."
+    )
+    parser.add_argument(
+        "--tokenizer_file",
+        default=None,
+        type=str,
+        help="Path to the tokenizer file to use (if not provided, only the model is converted).",
+    )
+    parser.add_argument(
+        "--size",
+        default=None,
+        type=str,
+        help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.",
+    )
+    parser.add_argument(
+        "--push_to_hub",
+        action="store_true",
+        help="Push to the Hub the converted model.",
+    )
+    parser.add_argument(
+        "--model_name",
+        default=None,
+        type=str,
+        help="Name of the pushed model on the Hub, including the username / organization.",
+    )
+
+    args = parser.parse_args()
+    convert_rmkv_checkpoint_to_hf_format(
+        args.repo_id,
+        args.checkpoint_file,
+        args.output_dir,
+        size=args.size,
+        tokenizer_file=args.tokenizer_file,
+        push_to_hub=args.push_to_hub,
+        model_name=args.model_name,
+    )
diff --git a/docs/transformers/src/transformers/models/rwkv/modeling_rwkv.py b/docs/transformers/src/transformers/models/rwkv/modeling_rwkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6274464309e3726e548622c584e354e129921e03
--- /dev/null
+++ b/docs/transformers/src/transformers/models/rwkv/modeling_rwkv.py
@@ -0,0 +1,850 @@
+# coding=utf-8
+# Copyright 2023 Bo Peng and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch RWKV model.""" + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...generation import GenerationMixin +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_bitsandbytes_available, + is_ninja_available, + is_torch_cuda_available, + logging, +) +from .configuration_rwkv import RwkvConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "RWKV/rwkv-4-169m-pile" +_CONFIG_FOR_DOC = "RwkvConfig" + + +rwkv_cuda_kernel = None + + +def load_wkv_cuda_kernel(context_length): + from torch.utils.cpp_extension import load as load_kernel + + global rwkv_cuda_kernel + + kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv" + cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]] + + # Only load the kernel if it's not been loaded yet or if we changed the context length + if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length: + return + + logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.") + + flags = [ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={context_length}", + ] + rwkv_cuda_kernel = load_kernel( + name=f"wkv_{context_length}", + sources=cuda_kernel_files, + verbose=(logging.get_verbosity() == logging.DEBUG), + extra_cuda_cflags=flags, + ) + rwkv_cuda_kernel.max_seq_length = context_length + + +class RwkvLinearAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False): + batch_size, seq_len, hidden_size = key.size() + if seq_len > rwkv_cuda_kernel.max_seq_length: + raise ValueError( + f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of " + f"{rwkv_cuda_kernel.max_seq_length} with this model." + ) + if batch_size * hidden_size % min(hidden_size, 32) != 0: + raise ValueError( + f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round " + f"multiple of {min(hidden_size, 32)}." + ) + + ctx.input_dtype = key.dtype + + if ( + time_decay.device.type != "cuda" + or time_first.device.type != "cuda" + or key.device.type != "cuda" + or value.device.type != "cuda" + ): + raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.") + + time_decay = -torch.exp(time_decay.float().contiguous()) + if key.dtype == torch.float16: + time_first = time_first.float() + key = key.float() + value = value.float() + time_first = time_first.contiguous() + key = key.contiguous() + value = value.contiguous() + # The CUDA kernel will fill this tensor. 
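+        # (it writes the per-timestep wkv outputs into `output` in place, so only the
+        # shape and dtype matter here; contiguous memory is what the kernel expects)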
+ output = torch.empty_like(key, memory_format=torch.contiguous_format) + if return_state or state is not None: + if state is None: + state = torch.zeros( + batch_size, + hidden_size, + 3, + dtype=torch.float32, + device=key.device, + memory_format=torch.contiguous_format, + ) + state[:, :, 2] -= 1e38 + else: + state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous() + if key.dtype == torch.bfloat16: + forward_func = rwkv_cuda_kernel.forward_with_state_bf16 + else: + forward_func = rwkv_cuda_kernel.forward_with_state + forward_func(time_decay, time_first, key, value, output, state) + else: + forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward + forward_func(time_decay, time_first, key, value, output) + + ctx.save_for_backward(time_decay, time_first, key, value, output) + + if state is not None: + state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)] + + return output.to(ctx.input_dtype), state + + @staticmethod + # g stands for grad + def backward(ctx, g_output, g_state=None): + input_dtype = ctx.input_dtype + + time_decay, time_first, key, value, output = ctx.saved_tensors + # The CUDA kernel will fill those tensors. + g_time_decay = torch.empty_like( + time_decay, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32, + ) + g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format) + g_key = torch.empty_like(key, memory_format=torch.contiguous_format) + g_value = torch.empty_like(value, memory_format=torch.contiguous_format) + + if input_dtype == torch.float16: + g_output = g_output.float() + backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward + backward_func( + time_decay, + time_first, + key, + value, + output, + g_output.contiguous(), + g_time_decay, + g_time_first, + g_key, + g_value, + ) + + return ( + g_time_decay.to(input_dtype), + g_time_first.to(input_dtype), + g_key.to(input_dtype), + g_value.to(input_dtype), + None, + None, + ) + + +def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False): + # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed + # within a torch.no_grad. 
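+    # For each position t, the loop below evaluates the WKV recurrence
+    #     wkv_t = (sum_{i<t} exp(k_i + (t-1-i) * w) * v_i + exp(k_t + u) * v_t)
+    #             / (sum_{i<t} exp(k_i + (t-1-i) * w) + exp(k_t + u))
+    # with w = -exp(time_decay) and u = time_first. The running numerator and denominator are stored
+    # together with a running maximum exponent (`max_state`) that is factored out of the exponentials,
+    # so none of them can overflow.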
+ _, seq_length, _ = key.size() + output = torch.zeros_like(key) + + if state is None: + num_state = torch.zeros_like(key[:, 0], dtype=torch.float32) + den_state = torch.zeros_like(key[:, 0], dtype=torch.float32) + max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38 + else: + num_state, den_state, max_state = state + # For numerical stability + # real_numerator_state = num_state * torch.exp(max_state) + # real_denominator_state = den_state * torch.exp(max_state) + + time_decay = -torch.exp(time_decay) + + for current_index in range(seq_length): + current_key = key[:, current_index].float() + current_value = value[:, current_index] + + # wkv computation at time t + max_for_output = torch.maximum(max_state, current_key + time_first) + e1 = torch.exp(max_state - max_for_output) + e2 = torch.exp(current_key + time_first - max_for_output) + numerator = e1 * num_state + e2 * current_value + denominator = e1 * den_state + e2 + output[:, current_index] = (numerator / denominator).to(output.dtype) + + # Update state for next iteration + max_for_state = torch.maximum(max_state + time_decay, current_key) + e1 = torch.exp(max_state + time_decay - max_for_state) + e2 = torch.exp(current_key - max_for_state) + num_state = e1 * num_state + e2 * current_value + den_state = e1 * den_state + e2 + max_state = max_for_state + + if return_state or state is not None: + state = [num_state, den_state, max_state] + + return output, state + + +def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False): + no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value]) + # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version + # in this case). + one_token = key.size(1) == 1 + if rwkv_cuda_kernel is None or no_cuda or one_token: + return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state) + else: + return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state) + + +class RwkvSelfAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.config = config + kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length + if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded: + try: + load_wkv_cuda_kernel(config.context_length) + except Exception: + logger.info("Could not load the custom CUDA kernel for RWKV attention.") + self.layer_id = layer_id + hidden_size = config.hidden_size + attention_hidden_size = ( + config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size + ) + self.attention_hidden_size = attention_hidden_size + + self.time_decay = nn.Parameter(torch.empty(attention_hidden_size)) + self.time_first = nn.Parameter(torch.empty(attention_hidden_size)) + + self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False) + + # TODO: maybe jit, otherwise move inside forward + def extract_key_value(self, 
hidden, state=None): + # Mix hidden with the previous timestep to produce key, value, receptance + if hidden.size(1) == 1 and state is not None: + shifted = state[1][:, :, self.layer_id] + else: + shifted = self.time_shift(hidden) + if state is not None: + shifted[:, 0] = state[1][:, :, self.layer_id] + key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) + value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value) + receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) + + key = self.key(key) + value = self.value(value) + receptance = torch.sigmoid(self.receptance(receptance)) + if state is not None: + state[1][:, :, self.layer_id] = hidden[:, -1] + return receptance, key, value, state + + def forward(self, hidden, state=None, use_cache=False): + receptance, key, value, state = self.extract_key_value(hidden, state=state) + layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None + rwkv, layer_state = rwkv_linear_attention( + self.time_decay, + self.time_first, + key, + value, + state=layer_state, + return_state=use_cache, + ) + + if layer_state is not None: + state[2][:, :, self.layer_id] = layer_state[0] + state[3][:, :, self.layer_id] = layer_state[1] + state[4][:, :, self.layer_id] = layer_state[2] + + return self.output(receptance * rwkv), state + + +class RwkvFeedForward(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.config = config + self.layer_id = layer_id + hidden_size = config.hidden_size + intermediate_size = ( + config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size + ) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) + + self.key = nn.Linear(hidden_size, intermediate_size, bias=False) + self.receptance = nn.Linear(hidden_size, hidden_size, bias=False) + self.value = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, hidden, state=None): + if hidden.size(1) == 1 and state is not None: + shifted = state[0][:, :, self.layer_id] + else: + shifted = self.time_shift(hidden) + if state is not None: + shifted[:, 0] = state[0][:, :, self.layer_id] + key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) + receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) + + key = torch.square(torch.relu(self.key(key))) + value = self.value(key) + receptance = torch.sigmoid(self.receptance(receptance)) + + if state is not None: + state[0][:, :, self.layer_id] = hidden[:, -1] + + return receptance * value, state + + +class RwkvBlock(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.config = config + self.layer_id = layer_id + + if layer_id == 0: + self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.attention = RwkvSelfAttention(config, layer_id) + self.feed_forward = RwkvFeedForward(config, layer_id) + + def forward(self, hidden, state=None, use_cache=False, output_attentions=False): + if self.layer_id == 0: + hidden = self.pre_ln(hidden) + + attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache) + hidden = hidden + attention + + feed_forward, state = 
self.feed_forward(self.ln2(hidden), state=state)
+        hidden = hidden + feed_forward
+
+        outputs = (hidden, state)
+        if output_attentions:
+            outputs += (attention,)
+        else:
+            outputs += (None,)
+
+        return outputs
+
+
+class RwkvPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = RwkvConfig
+    base_model_prefix = "rwkv"
+    _no_split_modules = ["RwkvBlock"]
+    _keep_in_fp32_modules = ["time_decay", "time_first"]
+    supports_gradient_checkpointing = True
+    _is_stateful = True
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, RwkvSelfAttention):
+            layer_id = module.layer_id
+            num_hidden_layers = module.config.num_hidden_layers
+            hidden_size = module.config.hidden_size
+            attention_hidden_size = module.attention_hidden_size
+
+            ratio_0_to_1 = layer_id / (num_hidden_layers - 1)  # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0
+
+            time_weight = torch.tensor(
+                [i / hidden_size for i in range(hidden_size)],
+                dtype=module.time_mix_key.dtype,
+                device=module.time_mix_key.device,
+            )
+            time_weight = time_weight[None, None, :]
+
+            decay_speed = [
+                -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+                for h in range(attention_hidden_size)
+            ]
+            decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
+            zigzag = (
+                torch.tensor(
+                    [(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
+                    dtype=module.time_first.dtype,
+                    device=module.time_first.device,
+                )
+                * 0.5
+            )
+
+            with torch.no_grad():
+                module.time_decay.data = decay_speed
+                # time_first starts at log(0.3) plus a small zigzag offset (`torch.ones_like` only
+                # provides the shape/dtype template here).
+                module.time_first.data = torch.ones_like(module.time_first) * math.log(0.3) + zigzag
+
+                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+                module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+                module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
+        elif isinstance(module, RwkvFeedForward):
+            layer_id = module.layer_id
+            num_hidden_layers = module.config.num_hidden_layers
+            hidden_size = module.config.hidden_size
+
+            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0
+
+            time_weight = torch.tensor(
+                [i / hidden_size for i in range(hidden_size)],
+                dtype=module.time_mix_key.dtype,
+                device=module.time_mix_key.device,
+            )
+            time_weight = time_weight[None, None, :]
+
+            with torch.no_grad():
+                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+                module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
+
+
+@dataclass
+class RwkvOutput(ModelOutput):
+    """
+    Class for the RWKV model outputs.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    state: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class RwkvCausalLMOutput(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    state: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+RWKV_START_DOCSTRING = r"""
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`RwkvConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+""" + +RWKV_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + This is currently not used by `RwkvModel`, but will be supported in the future. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the last state is returned and can be used to quickly generate the next logits. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.", + RWKV_START_DOCSTRING, +) +class RwkvModel(RwkvPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) + self.ln_out = nn.LayerNorm(config.hidden_size) + + self.layers_are_rescaled = False + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=RwkvOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + state: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RwkvOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None: + logger.warning_once("`attention_mask` was passed, but it is unused in this model.") + + if self.training == self.layers_are_rescaled: + self._rescale_layers() + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if use_cache and state is None: + shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) + state = [ + torch.zeros( + *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device + ) + for i in range(5) + ] + state[4] -= 1e30 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + hidden_states = inputs_embeds + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for idx, block in enumerate(self.blocks): + if self.gradient_checkpointing and self.training: + hidden_states, state, attentions = self._gradient_checkpointing_func( + block.__call__, hidden_states, state, use_cache, output_attentions + ) + else: + hidden_states, state, attentions = block( + hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions + ) + + if ( + self.layers_are_rescaled + and self.config.rescale_every > 0 + and (idx + 1) % self.config.rescale_every == 0 + ): + hidden_states = hidden_states / 2 + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + hidden_states = self.ln_out(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None) + + return RwkvOutput( + last_hidden_state=hidden_states, + state=state, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _rescale_layers(self): + # Layers should be rescaled for inference only. + if self.layers_are_rescaled == (not self.training): + return + if self.config.rescale_every > 0: + with torch.no_grad(): + for block_id, block in enumerate(self.blocks): + if self.training: + block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every)) + else: + # Deal with quantization statistics + if hasattr(block.attention.output.weight, "SCB"): + block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) + elif hasattr(block.attention.output.weight, "quant_state"): + self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id) + self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id) + else: + block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every)) + + self.layers_are_rescaled = not self.training + + def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): + r""" + Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will + be quantized again. 
+ """ + if not is_bitsandbytes_available(): + raise ImportError("Please install bitsandbytes to use this method.") + import bitsandbytes as bnb + + dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state) + + dequant_weights.div_(2 ** int(block_id // self.config.rescale_every)) + + # re-quantize the model: + # we need to put it first on CPU then back to the device + # this will create an overhead :/ + # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid + # bugs with bnb + quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device) + setattr(target_layer, "weight", quant_weight) + + +@add_start_docstrings( + """ + The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + RWKV_START_DOCSTRING, +) +class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["head.weight"] + + def __init__(self, config): + super().__init__(config) + self.rwkv = RwkvModel(config) + self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.head + + def set_output_embeddings(self, new_embeddings): + self.head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs): + # Overwritten -- this model uses `state`, but doesn't have a cache (`past_key_values`) + + # only last token for inputs_ids if the state is passed along. + if state is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and state is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs["state"] = state + model_inputs["use_cache"] = use_cache + return model_inputs + + @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=RwkvCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + state: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, RwkvCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        rwkv_outputs = self.rwkv(
+            input_ids,
+            inputs_embeds=inputs_embeds,
+            state=state,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = rwkv_outputs[0]
+
+        logits = self.head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (logits,) + rwkv_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return RwkvCausalLMOutput(
+            loss=loss,
+            logits=logits,
+            state=rwkv_outputs.state,
+            hidden_states=rwkv_outputs.hidden_states,
+            attentions=rwkv_outputs.attentions,
+        )
+
+
+__all__ = ["RwkvForCausalLM", "RwkvModel", "RwkvPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/sam/__init__.py b/docs/transformers/src/transformers/models/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68da4037a351b308e7301835795b6c68946b6f63
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sam/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_sam import *
+    from .image_processing_sam import *
+    from .modeling_sam import *
+    from .modeling_tf_sam import *
+    from .processing_sam import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/sam/configuration_sam.py b/docs/transformers/src/transformers/models/sam/configuration_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f4a2b0893abde34374c5c0acc694c20a3de3083
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sam/configuration_sam.py
@@ -0,0 +1,337 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SAM model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SamPromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamPromptEncoder`]. The [`SamPromptEncoder`] + module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield + a similar configuration to that of the SAM-vit-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.image_embedding_size = image_size // patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + + +class SamMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamMaskDecoder`]. It is used to instantiate a SAM + mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults + will yield a similar configuration to that of the SAM-vit-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function used inside the `SamMaskDecoder` module. + mlp_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 2): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsampling rate of the attention layer. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of outputs from the `SamMaskDecoder` module. In the Segment Anything paper, this is set to 3. 
+        iou_head_depth (`int`, *optional*, defaults to 3):
+            The number of layers in the IoU head module.
+        iou_head_hidden_dim (`int`, *optional*, defaults to 256):
+            The dimensionality of the hidden states in the IoU head module.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+
+    """
+
+    base_config_key = "mask_decoder_config"
+
+    def __init__(
+        self,
+        hidden_size=256,
+        hidden_act="relu",
+        mlp_dim=2048,
+        num_hidden_layers=2,
+        num_attention_heads=8,
+        attention_downsample_rate=2,
+        num_multimask_outputs=3,
+        iou_head_depth=3,
+        iou_head_hidden_dim=256,
+        layer_norm_eps=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.mlp_dim = mlp_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.attention_downsample_rate = attention_downsample_rate
+        self.num_multimask_outputs = num_multimask_outputs
+        self.iou_head_depth = iou_head_depth
+        self.iou_head_hidden_dim = iou_head_hidden_dim
+        self.layer_norm_eps = layer_norm_eps
+
+
+class SamVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SamVisionModel`]. It is used to instantiate a SAM
+    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the SAM ViT-h
+    [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        output_channels (`int`, *optional*, defaults to 256):
+            Dimensionality of the output channels in the Patch Encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input image.
+        image_size (`int`, *optional*, defaults to 1024):
+            Expected resolution. Target size of the resized input image.
+        patch_size (`int`, *optional*, defaults to 16):
+            Size of the patches to be extracted from the input image.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 1e-10):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to query, key, value projections.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            Ratio of mlp hidden dim to embedding dim.
+        use_abs_pos (`bool`, *optional*, defaults to `True`):
+            Whether to use absolute position embedding.
+        use_rel_pos (`bool`, *optional*, defaults to `True`):
+            Whether to use relative position embedding.
+        window_size (`int`, *optional*, defaults to 14):
+            Window size for relative position.
+ global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + num_pos_feats (`int`, *optional*, defaults to 128): + The dimensionality of the position embedding. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. + + Example: + + ```python + >>> from transformers import ( + ... SamVisionConfig, + ... SamVisionModel, + ... ) + + >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration + >>> configuration = SamVisionConfig() + + >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration + >>> model = SamVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + base_config_key = "vision_config" + model_type = "sam_vision_model" + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=128, + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim + + +class SamConfig(PretrainedConfig): + r""" + [`SamConfig`] is the configuration class to store the configuration of a [`SamModel`]. It is used to instantiate a + SAM model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + SAM-ViT-H [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `SamVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamVisionConfig`]. + prompt_encoder_config (Union[`dict`, `SamPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `SamMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamMaskDecoderConfig`]. + + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ( + ... SamVisionConfig, + ... SamPromptEncoderConfig, + ... SamMaskDecoderConfig, + ... SamModel, + ... 
)
+
+    >>> # Initializing a SamConfig with `"facebook/sam-vit-huge"` style configuration
+    >>> configuration = SamConfig()
+
+    >>> # Initializing a SamModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+    >>> model = SamModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a SamConfig from a SamVisionConfig, SamPromptEncoderConfig, and SamMaskDecoderConfig
+
+    >>> # Initializing SAM vision, prompt encoder and mask decoder configurations
+    >>> vision_config = SamVisionConfig()
+    >>> prompt_encoder_config = SamPromptEncoderConfig()
+    >>> mask_decoder_config = SamMaskDecoderConfig()
+
+    >>> config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+    ```"""
+
+    model_type = "sam"
+    sub_configs = {
+        "prompt_encoder_config": SamPromptEncoderConfig,
+        "mask_decoder_config": SamMaskDecoderConfig,
+        "vision_config": SamVisionConfig,
+    }
+
+    def __init__(
+        self,
+        vision_config=None,
+        prompt_encoder_config=None,
+        mask_decoder_config=None,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        vision_config = vision_config if vision_config is not None else {}
+        prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+        mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+        if isinstance(vision_config, SamVisionConfig):
+            vision_config = vision_config.to_dict()
+        if isinstance(prompt_encoder_config, SamPromptEncoderConfig):
+            prompt_encoder_config = prompt_encoder_config.to_dict()
+        if isinstance(mask_decoder_config, SamMaskDecoderConfig):
+            mask_decoder_config = mask_decoder_config.to_dict()
+
+        self.vision_config = SamVisionConfig(**vision_config)
+        self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
+        self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
+        self.initializer_range = initializer_range
+
+
+__all__ = ["SamConfig", "SamMaskDecoderConfig", "SamPromptEncoderConfig", "SamVisionConfig"]
diff --git a/docs/transformers/src/transformers/models/sam/convert_sam_to_hf.py b/docs/transformers/src/transformers/models/sam/convert_sam_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d8884d951521062ea9f969a1407fa0f0e649ef
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sam/convert_sam_to_hf.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert SAM checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/segment-anything.
+
+Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
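+
+Example invocation (illustrative; the flags mirror the argparse options defined at the bottom of this script,
+and `--checkpoint_path` only needs to be passed for the SlimSAM variants):
+
+    python convert_sam_to_hf.py --model_name sam_vit_b_01ec64 --pytorch_dump_folder_path ./sam-vit-base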
+""" + +import argparse +import re + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + SamConfig, + SamImageProcessor, + SamModel, + SamProcessor, + SamVisionConfig, +) + + +def get_config(model_name): + if "slimsam-50" in model_name: + vision_config = SamVisionConfig( + hidden_size=384, + mlp_dim=1536, + num_hidden_layers=12, + num_attention_heads=12, + global_attn_indexes=[2, 5, 8, 11], + ) + elif "slimsam-77" in model_name: + vision_config = SamVisionConfig( + hidden_size=168, + mlp_dim=696, + num_hidden_layers=12, + num_attention_heads=12, + global_attn_indexes=[2, 5, 8, 11], + ) + elif "sam_vit_b" in model_name: + vision_config = SamVisionConfig() + elif "sam_vit_l" in model_name: + vision_config = SamVisionConfig( + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + global_attn_indexes=[5, 11, 17, 23], + ) + elif "sam_vit_h" in model_name: + vision_config = SamVisionConfig( + hidden_size=1280, + num_hidden_layers=32, + num_attention_heads=16, + global_attn_indexes=[7, 15, 23, 31], + ) + + config = SamConfig( + vision_config=vision_config, + ) + + return config + + +KEYS_TO_MODIFY_MAPPING = { + "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", + "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", + "iou_prediction_head.layers.2": "iou_prediction_head.proj_out", + "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1", + "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm", + "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2", + "mask_downscaling.0": "mask_embed.conv1", + "mask_downscaling.1": "mask_embed.layer_norm1", + "mask_downscaling.3": "mask_embed.conv2", + "mask_downscaling.4": "mask_embed.layer_norm2", + "mask_downscaling.6": "mask_embed.conv3", + "point_embeddings": "point_embed", + "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", + "image_encoder": "vision_encoder", + "neck.0": "neck.conv1", + "neck.1": "neck.layer_norm1", + "neck.2": "neck.conv2", + "neck.3": "neck.layer_norm2", + "patch_embed.proj": "patch_embed.projection", + ".norm": ".layer_norm", + "blocks": "layers", +} + + +def replace_keys(state_dict): + model_state_dict = {} + state_dict.pop("pixel_mean", None) + state_dict.pop("pixel_std", None) + + output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*" + + for key, value in state_dict.items(): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + if re.match(output_hypernetworks_mlps_pattern, key): + layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + model_state_dict[key] = value + + model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ + "prompt_encoder.shared_embedding.positional_embedding" + ] + + return model_state_dict + + +def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub): + config = get_config(model_name) + + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + state_dict = replace_keys(state_dict) + + image_processor = SamImageProcessor() + processor = SamProcessor(image_processor=image_processor) + 
hf_model = SamModel(config) + hf_model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + hf_model.load_state_dict(state_dict) + hf_model = hf_model.to(device) + + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + input_points = [[[500, 375]]] + input_labels = [[1]] + + inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + if model_name == "sam_vit_b_01ec64": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + elif model_name == "sam_vit_h_4b8939": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert scores[-1].item() == 0.9712603092193604 + + input_boxes = ((75, 275, 1725, 850),) + + inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert scores[-1].item() == 0.8686015605926514 + + # Test with 2 points and 1 image. + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert scores[-1].item() == 0.9936047792434692 + + if pytorch_dump_folder is not None: + processor.save_pretrained(pytorch_dump_folder) + hf_model.save_pretrained(pytorch_dump_folder) + + if push_to_hub: + repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}" + processor.push_to_hub(repo_id) + hf_model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"] + parser.add_argument( + "--model_name", + default="sam_vit_h_4b8939", + choices=choices, + type=str, + help="Name of the original model to convert", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=False, + help="Path to the original checkpoint", + ) + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model and processor to the hub after converting", + ) + + args = parser.parse_args() + + if "slimsam" in args.model_name: + checkpoint_path = args.checkpoint_path + if checkpoint_path is None: + raise ValueError("You need to provide a checkpoint path for SlimSAM models.") + else: + checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth") + + convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/sam/image_processing_sam.py b/docs/transformers/src/transformers/models/sam/image_processing_sam.py new file mode 100644 index 
0000000000000000000000000000000000000000..5d98674f73e68e2a63c51c367965bb0bb59128f0
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sam/image_processing_sam.py
@@ -0,0 +1,1494 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for SAM."""
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    get_image_size,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import (
+    TensorType,
+    filter_out_non_signature_kwargs,
+    is_tf_available,
+    is_torch_available,
+    is_torchvision_available,
+    logging,
+    requires_backends,
+)
+
+
+if is_torch_available():
+    import torch
+    import torch.nn.functional as F
+
+if is_torchvision_available():
+    from torchvision.ops.boxes import batched_nms
+
+if is_tf_available():
+    import tensorflow as tf
+    from tensorflow.experimental import numpy as tnp
+
+    from ...tf_utils import flatten, shape_list
+
+logger = logging.get_logger(__name__)
+
+
+class SamImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a SAM image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
+            Size of the output image after resizing. Resizes the longest edge of the image to match
+            `size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `size` parameter in the
+            `preprocess` method.
+        mask_size (`dict`, *optional*, defaults to `{"longest_edge": 256}`):
+            Size of the output segmentation map after resizing. Resizes the longest edge of the map to match
+            `mask_size["longest_edge"]` while maintaining the aspect ratio. Can be overridden by the `mask_size`
+            parameter in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+            `do_rescale` parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`.
Can be
+            overridden by the `rescale_factor` parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_pad (`bool`, *optional*, defaults to `True`):
+            Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
+            `preprocess` method.
+        pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
+            Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
+            method.
+        mask_pad_size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
+            Size of the output segmentation map after padding. Can be overridden by the `mask_pad_size` parameter in
+            the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[Dict[str, int]] = None,
+        mask_size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_pad: bool = True,
+        pad_size: Optional[Dict[str, int]] = None,
+        mask_pad_size: Optional[Dict[str, int]] = None,
+        do_convert_rgb: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"longest_edge": 1024}
+        size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
+
+        pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
+        pad_size = get_size_dict(pad_size, default_to_square=True)
+
+        mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
+        mask_size = (
+            get_size_dict(max_size=mask_size, default_to_square=False)
+            if not isinstance(mask_size, dict)
+            else mask_size
+        )
+
+        mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
+        mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.mask_size = mask_size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.do_pad = do_pad
+        self.pad_size = pad_size
+        self.mask_pad_size =
mask_pad_size + self.do_convert_rgb = do_convert_rgb + + def pad_image( + self, + image: np.ndarray, + pad_size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom. + + Args: + image (`np.ndarray`): + Image to pad. + pad_size (`Dict[str, int]`): + Size of the output image after padding. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the + `data_format` of the `image` will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + output_height, output_width = pad_size["height"], pad_size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + pad_width = output_width - input_width + pad_height = output_height - input_height + + padded_image = pad( + image, + ((0, pad_height), (0, pad_width)), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return padded_image + + def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int): + """ + Compute the output size given input size and target long side length. + """ + oldh, oldw = old_shape + scale = longest_edge * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + newh = int(newh + 0.5) + neww = int(neww + 0.5) + return (newh, neww) + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest + edge of the image will be resized to the specified size, while the other edge will be resized to + maintain the aspect ratio. + resample: + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "longest_edge" not in size: + raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. 
Got {size.keys()}") + input_size = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"]) + return resize( + image, + size=(output_height, output_width), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool, + do_rescale: bool, + do_normalize: bool, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + rescale_factor: Optional[float] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + reshaped_input_size = get_image_size(image, channel_dim=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + if do_pad: + image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format) + + return image, reshaped_input_size + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Dict[str, int]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]: + # PIL RGBA images are converted to RGB + if do_convert_rgb: + image = convert_to_rgb(image) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ ) + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + original_size = get_image_size(image, channel_dim=input_data_format) + + image, reshaped_input_size = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + pad_size=pad_size, + input_data_format=input_data_format, + ) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image, original_size, reshaped_input_size + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: Optional[bool] = None, + mask_size: Dict[str, int] = None, + do_pad: Optional[bool] = None, + mask_pad_size: Optional[Dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + segmentation_map = to_numpy_array(segmentation_map) + + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + input_data_format = ChannelDimension.FIRST + else: + added_channel_dim = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + + original_size = get_image_size(segmentation_map, channel_dim=input_data_format) + + segmentation_map, _ = self._preprocess( + image=segmentation_map, + do_resize=do_resize, + size=mask_size, + resample=PILImageResampling.NEAREST, + do_rescale=False, + do_normalize=False, + do_pad=do_pad, + pad_size=mask_pad_size, + input_data_format=input_data_format, + ) + + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + segmentation_map = segmentation_map.astype(np.int64) + + return segmentation_map, original_size + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + mask_size: Optional[Dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Dict[str, int]] = None, + mask_pad_size: Optional[Dict[str, int]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. 
The longest edge of the image is resized to
+                `size["longest_edge"]` whilst preserving the aspect ratio.
+            mask_size (`Dict[str, int]`, *optional*, defaults to `self.mask_size`):
+                Controls the size of the segmentation map after `resize`. The longest edge of the map is resized to
+                `mask_size["longest_edge"]` whilst preserving the aspect ratio.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image pixel values by the rescale factor.
+            rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to apply to the image pixel values.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to normalize the image by if `do_normalize` is set to `True`.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                Whether to pad the image.
+            pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
+                Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
+                `pad_size["width"]` if `do_pad` is set to `True`.
+            mask_pad_size (`Dict[str, int]`, *optional*, defaults to `self.mask_pad_size`):
+                Controls the size of the padding applied to the segmentation map. The map is padded to
+                `mask_pad_size["height"]` and `mask_pad_size["width"]` if `do_pad` is set to `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                - Unset: Return a list of `np.ndarray`.
+                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
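+
+        Example (a minimal sketch with default settings; the random input image is purely illustrative):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import SamImageProcessor
+
+        >>> processor = SamImageProcessor()
+        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
+        >>> inputs = processor.preprocess(image, return_tensors="np")
+        >>> list(inputs.keys())
+        ['pixel_values', 'original_sizes', 'reshaped_input_sizes']
+        >>> inputs["pixel_values"].shape  # longest edge resized to 1024, then padded to 1024x1024
+        (1, 3, 1024, 1024)
+        ```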
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size + mask_size = mask_size if mask_size is not None else self.mask_size + mask_size = ( + get_size_dict(max_size=mask_size, default_to_square=False) + if not isinstance(mask_size, dict) + else mask_size + ) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + pad_size = pad_size if pad_size is not None else self.pad_size + pad_size = get_size_dict(pad_size, default_to_square=True) + mask_pad_size = mask_pad_size if mask_pad_size is not None else self.mask_pad_size + mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True) + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + + if not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size. + do_resize=do_resize, + size=size, + resample=resample, + ) + + images, original_sizes, reshaped_input_sizes = zip( + *( + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + pad_size=pad_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ) + ) + + data = { + "pixel_values": images, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + if segmentation_maps is not None: + segmentation_maps, original_mask_sizes = zip( + *( + self._preprocess_mask( + segmentation_map=mask, + do_resize=do_resize, + mask_size=mask_size, + do_pad=do_pad, + mask_pad_size=mask_pad_size, + input_data_format=input_data_format, + ) + for mask in segmentation_maps + ) + ) + + # masks should start out the same size as input images + assert all( + original_im_size == original_mask_size + for original_im_size, original_mask_size in zip(original_sizes, original_mask_sizes) + ), "Segmentation maps should be the same size as input images." 
+ + data["labels"] = segmentation_maps + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_masks( + self, + masks, + original_sizes, + reshaped_input_sizes, + mask_threshold=0.0, + binarize=True, + pad_size=None, + return_tensors="pt", + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + return_tensors (`str`, *optional*, defaults to `"pt"`): + If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors. + Returns: + (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where + (height, width) is given by original_size. + """ + if return_tensors == "pt": + return self._post_process_masks_pt( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + elif return_tensors == "tf": + return self._post_process_masks_tf( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'") + + def _post_process_masks_pt( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
+ """ + requires_backends(self, ["torch"]) + pad_size = self.pad_size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + if isinstance(original_sizes, (torch.Tensor, np.ndarray)): + original_sizes = original_sizes.tolist() + if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)): + reshaped_input_sizes = reshaped_input_sizes.tolist() + output_masks = [] + for i, original_size in enumerate(original_sizes): + if isinstance(masks[i], np.ndarray): + masks[i] = torch.from_numpy(masks[i]) + elif not isinstance(masks[i], torch.Tensor): + raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`") + interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False) + interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] + interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False) + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + + def _post_process_masks_tf( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`tf.Tensor`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`tf.Tensor`): + The original size of the images before resizing for input to the model, in (height, width) format. + reshaped_input_sizes (`tf.Tensor`): + The size of the image input to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is + given by original_size. + """ + requires_backends(self, ["tf"]) + pad_size = self.pad_size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + + output_masks = [] + for i, original_size in enumerate(original_sizes): + # tf.image expects NHWC, we transpose the NCHW inputs for it + mask = tf.transpose(masks[i], perm=[0, 2, 3, 1]) + interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear") + interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :] + interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + # And then we transpose them back at the end + output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2])) + + return output_masks + + def post_process_for_mask_generation( + self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt" + ): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. 
+
+        Args:
+            all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
+                List of all predicted segmentation masks
+            all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
+                List of all predicted iou scores
+            all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
+                List of all bounding boxes of the predicted masks
+            crops_nms_thresh (`float`):
+                Threshold for NMS (Non Maximum Suppression) algorithm.
+            return_tensors (`str`, *optional*, defaults to `"pt"`):
+                If `"pt"`, returns `torch.Tensor`. If `"tf"`, returns `tf.Tensor`.
+        """
+        if return_tensors == "pt":
+            return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
+        elif return_tensors == "tf":
+            return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
+        else:
+            raise ValueError("return_tensors must be either 'pt' or 'tf'.")
+
+    def generate_crop_boxes(
+        self,
+        image,
+        target_size,
+        crop_n_layers: int = 0,
+        overlap_ratio: float = 512 / 1500,
+        points_per_crop: Optional[int] = 32,
+        crop_n_points_downscale_factor: Optional[int] = 1,
+        device: Optional["torch.device"] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        return_tensors: str = "pt",
+    ):
+        """
+        Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+
+        Args:
+            image (`np.array`):
+                Input original image
+            target_size (`int`):
+                Target size of the resized image
+            crop_n_layers (`int`, *optional*, defaults to 0):
+                If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
+                each layer has 2**i_layer number of image crops.
+            overlap_ratio (`float`, *optional*, defaults to 512/1500):
+                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
+                the image length. Later layers with more crops scale down this overlap.
+            points_per_crop (`int`, *optional*, defaults to 32):
+                Number of points to sample from each crop.
+            crop_n_points_downscale_factor (`int`, *optional*, defaults to 1):
+                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+            device (`torch.device`, *optional*, defaults to None):
+                Device to use for the computation. If None, cpu will be used.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+            return_tensors (`str`, *optional*, defaults to `"pt"`):
+                If `"pt"`, returns `torch.Tensor`. If `"tf"`, returns `tf.Tensor`.
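+
+        Example (a minimal sketch; the random image is illustrative, and the default `return_tensors="pt"` requires
+        PyTorch):
+
+        ```python
+        >>> import numpy as np
+
+        >>> processor = SamImageProcessor()
+        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
+        >>> crop_boxes, points_per_crop, cropped_images, input_labels = processor.generate_crop_boxes(
+        ...     image, target_size=1024, crop_n_layers=1
+        ... )
+        >>> crop_boxes.shape  # 1 full-image box plus 2**2 boxes for layer 1
+        torch.Size([5, 4])
+        ```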
+ """ + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + input_data_format, + ) + if return_tensors == "pt": + if device is None: + device = torch.device("cpu") + crop_boxes = torch.tensor(crop_boxes, device=device) + points_per_crop = torch.tensor(points_per_crop, device=device) + # cropped_images stays as np + input_labels = torch.tensor(input_labels, device=device) + + elif return_tensors == "tf": + if device is not None: + raise ValueError("device is not a supported argument when return_tensors is tf!") + crop_boxes = tf.convert_to_tensor(crop_boxes) + points_per_crop = tf.convert_to_tensor(points_per_crop) + # cropped_images stays as np + input_labels = tf.convert_to_tensor(input_labels) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'.") + return crop_boxes, points_per_crop, cropped_images, input_labels + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + return_tensors="pt", + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`Union[torch.Tensor, tf.Tensor]`): + Input masks. + iou_scores (`Union[torch.Tensor, tf.Tensor]`): + List of IoU scores. + original_size (`Tuple[int,int]`): + Size of the orginal image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + """ + if return_tensors == "pt": + return self._filter_masks_pt( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + elif return_tensors == "tf": + return self._filter_masks_tf( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + + def _filter_masks_pt( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. 
The second criterion is that the stability
+        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+        bounding boxes and pads the predicted masks if necessary.
+
+        Args:
+            masks (`torch.Tensor`):
+                Input masks.
+            iou_scores (`torch.Tensor`):
+                List of IoU scores.
+            original_size (`Tuple[int,int]`):
+                Size of the original image.
+            cropped_box_image (`np.array`):
+                The cropped image.
+            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+                The threshold for the iou scores.
+            stability_score_thresh (`float`, *optional*, defaults to 0.95):
+                The threshold for the stability score.
+            mask_threshold (`float`, *optional*, defaults to 0):
+                The threshold for the predicted masks.
+            stability_score_offset (`float`, *optional*, defaults to 1):
+                The offset for the stability score used in the `_compute_stability_score` method.
+
+        """
+        requires_backends(self, ["torch"])
+        original_height, original_width = original_size
+        iou_scores = iou_scores.flatten(0, 1)
+        masks = masks.flatten(0, 1)
+
+        if masks.shape[0] != iou_scores.shape[0]:
+            raise ValueError("masks and iou_scores must have the same batch size.")
+
+        if masks.device != iou_scores.device:
+            iou_scores = iou_scores.to(masks.device)
+
+        batch_size = masks.shape[0]
+
+        keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
+
+        if pred_iou_thresh > 0.0:
+            keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
+
+        # compute stability score
+        if stability_score_thresh > 0.0:
+            stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
+            keep_mask = keep_mask & (stability_scores > stability_score_thresh)
+
+        scores = iou_scores[keep_mask]
+        masks = masks[keep_mask]
+
+        # binarize masks
+        masks = masks > mask_threshold
+        converted_boxes = _batched_mask_to_box(masks)
+
+        keep_mask = ~_is_box_near_crop_edge(
+            converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
+        )
+
+        scores = scores[keep_mask]
+        masks = masks[keep_mask]
+        converted_boxes = converted_boxes[keep_mask]
+
+        masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
+        # conversion to rle is necessary to run non-maximum suppression
+        masks = _mask_to_rle_pytorch(masks)
+
+        return masks, scores, converted_boxes
+
+    def _filter_masks_tf(
+        self,
+        masks,
+        iou_scores,
+        original_size,
+        cropped_box_image,
+        pred_iou_thresh=0.88,
+        stability_score_thresh=0.95,
+        mask_threshold=0,
+        stability_score_offset=1,
+    ):
+        """
+        Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
+        that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
+        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
+        bounding boxes and pads the predicted masks if necessary.
+
+        Args:
+            masks (`tf.Tensor`):
+                Input masks.
+            iou_scores (`tf.Tensor`):
+                List of IoU scores.
+            original_size (`Tuple[int,int]`):
+                Size of the original image.
+            cropped_box_image (`np.array`):
+                The cropped image.
+            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
+                The threshold for the iou scores.
+            stability_score_thresh (`float`, *optional*, defaults to 0.95):
+                The threshold for the stability score.
+            mask_threshold (`float`, *optional*, defaults to 0):
+                The threshold for the predicted masks.
+            stability_score_offset (`float`, *optional*, defaults to 1):
+                The offset for the stability score used in the `_compute_stability_score` method.
+ + """ + requires_backends(self, ["tf"]) + original_height, original_width = original_size + iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]]) + masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]]) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + batch_size = masks.shape[0] + + keep_mask = tf.ones(batch_size, dtype=tf.bool) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box_tf(masks) + + keep_mask = ~_is_box_near_crop_edge_tf( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppresion + masks = _mask_to_rle_tf(masks) + + return masks, scores, converted_boxes + + +def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): + # One mask is always contained inside the other. + # Save memory by preventing unnecesary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + ) + unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + stability_scores = intersections / unions + return stability_scores + + +def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int): + # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure + # we get the right division results. + intersections = tf.count_nonzero( + masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32 + ) + unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32) + stability_scores = intersections / unions + return stability_scores + + +def _build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def _normalize_coordinates( + target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False +) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. 
+ """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _generate_crop_boxes( + image, + target_size: int, # Is it tuple here? + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[List[int]] = 1, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. + overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*): + Number of points to sample per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + image = to_numpy_array(image) + original_size = get_image_size(image, input_data_format) + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format + ) + crop_boxes = np.array(crop_boxes) + crop_boxes = crop_boxes.astype(np.float32) + points_per_crop = np.array([point_grid_per_crop]) + points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3)) + + input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. 
Crops are appended in XYXY format:
+    `[x0, y0, x1, y1]`, where
+        - x0, y0: coordinates of the top left corner of the bounding box
+        - x1, y1: coordinates of the bottom right corner of the bounding box
+    """
+    crop_boxes, layer_idxs = [], []
+    im_height, im_width = original_size
+    short_side = min(im_height, im_width)
+
+    # Original image
+    crop_boxes.append([0, 0, im_width, im_height])
+    layer_idxs.append(0)
+    for i_layer in range(crop_n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
+        crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
+
+        crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
+
+        for left, top in product(crop_box_x0, crop_box_y0):
+            box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
+
+def _generate_crop_images(
+    crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
+):
+    """
+    Takes as an input bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
+    also passed.
+    """
+    cropped_images = []
+    total_points_per_crop = []
+    for i, crop_box in enumerate(crop_boxes):
+        left, top, right, bottom = crop_box
+
+        channel_dim = infer_channel_dimension_format(image, input_data_format)
+        if channel_dim == ChannelDimension.LAST:
+            cropped_im = image[top:bottom, left:right, :]
+        else:
+            cropped_im = image[:, top:bottom, left:right]
+
+        cropped_images.append(cropped_im)
+
+        cropped_im_size = get_image_size(cropped_im, channel_dim)
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+
+        points = points_grid[layer_idxs[i]] * points_scale
+        normalized_points = _normalize_coordinates(target_size, points, original_size)
+        total_points_per_crop.append(normalized_points)
+
+    return cropped_images, total_points_per_crop
+
+
+def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
+    left, top, right, bottom = crop_box
+    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+    pad = (left, pad_x - left, top, pad_y - top)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
+    left, top, right, bottom = crop_box
+    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
+    # tf.pad expects one (before, after) pair per dimension; only the last two (height, width) dims are padded
+    paddings = [[0, 0]] * (len(masks.shape) - 2) + [[top, pad_y - top], [left, pad_x - left]]
+    return tf.pad(masks, paddings, constant_values=0)
+
+
+def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+
+    left, top, _, _ = crop_box
+    offset = torch.tensor([[left, top, left, top]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    boxes = (boxes + offset).float()
+
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
+    orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
+
+    left, top, _, _ = crop_box
+    offset = tf.convert_to_tensor([[left, top, left, top]])
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = tf.expand_dims(offset, 1)
+    boxes = tf.cast(boxes + offset, tf.float32)
+
+    near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
+    near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
+    near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
+    return tf.reduce_any(near_crop_edge, axis=1)
+
+
+def _batched_mask_to_box(masks: "torch.Tensor"):
+    """
+    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+        - LEFT: left hand side of the bounding box
+        - TOP: top of the bounding box
+        - RIGHT: right of the bounding box
+        - BOTTOM: bottom of the bounding box
+
+    Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
+    is channel_1 x channel_2 x ... x 4.
+
+    Args:
+        - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to Cxheightxwidth
+    shape = masks.shape
+    height, width = shape[-2:]
+
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + height * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + width * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    out = out.reshape(*shape[:-2], 4)
+    return out
+
+
+def _batched_mask_to_box_tf(masks: "tf.Tensor"):
+    """
+    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
+    corresponds to the following required indices:
+        - LEFT: left hand side of the bounding box
+        - TOP: top of the bounding box
+        - RIGHT: right of the bounding box
+        - BOTTOM: bottom of the bounding box
+
+    Return [0,0,0,0] for an empty mask.
For input shape channel_1 x channel_2 x ... x height x width, the output shape
+    is channel_1 x channel_2 x ... x 4.
+
+    Args:
+        - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
+    """
+
+    if tf.size(masks) == 0:
+        return tf.zeros([*masks.shape[:-2], 4])
+
+    # Normalize shape to Cxheightxwidth
+    shape = shape_list(masks)
+    height, width = shape[-2:]
+
+    # Get top and bottom edges
+    in_height = tf.reduce_max(masks, axis=-1)
+    in_height_coords = in_height * tf.range(height)[None, :]
+    bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
+    in_height_coords = in_height_coords + height * (~in_height)
+    top_edges = tf.reduce_min(in_height_coords, axis=-1)
+
+    # Get left and right edges
+    in_width = tf.reduce_max(masks, axis=-2)
+    in_width_coords = in_width * tf.range(width)[None, :]
+    right_edges = tf.reduce_max(in_width_coords, axis=-1)
+    in_width_coords = in_width_coords + width * (~in_width)
+    left_edges = tf.reduce_min(in_width_coords, axis=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
+    out = out * tf.expand_dims(~empty_filter, -1)
+
+    # Return to original shape
+    out = tf.reshape(out, [*shape[:-2], 4])
+    return out
+
+
+def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
+    """
+    Encodes masks into run-length encoding (RLE), in the format expected by pycocotools.
+    """
+    # Put in fortran order and flatten height and width
+    batch_size, height, width = input_mask.shape
+    input_mask = input_mask.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+    change_indices = diff.nonzero()
+
+    # Encode run length
+    out = []
+    for i in range(batch_size):
+        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+        if len(cur_idxs) == 0:
+            # No changes => either all 0 or all 1
+            # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
+            if input_mask[i, 0] == 0:
+                out.append({"size": [height, width], "counts": [height * width]})
+            else:
+                out.append({"size": [height, width], "counts": [0, height * width]})
+            continue
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if input_mask[i, 0] == 0 else [0]
+        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
+        out.append({"size": [height, width], "counts": counts})
+    return out
+
+
+def _mask_to_rle_tf(input_mask: "tf.Tensor"):
+    """
+    Encodes masks into run-length encoding (RLE), in the format expected by pycocotools.
+    """
+    # Put in fortran order and flatten height and width
+    batch_size, height, width = input_mask.shape
+    input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
+
+    # Compute change indices
+    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
+    change_indices = tf.where(diff)
+
+    # Encode run length
+    out = []
+    for i in range(batch_size):
+        cur_idxs = tf.boolean_mask(change_indices, change_indices[:, 0] == i)[:, 1] + 1
+        if len(cur_idxs) == 0:
+            # No changes => either all 0 or all 1
+            # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
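+            # The counts list alternates run lengths starting with the number of leading zeros; this is the
+            # uncompressed layout that `_rle_to_mask` decodes below.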
+            if input_mask[i, 0] == 0:
+                out.append({"size": [height, width], "counts": [height * width]})
+            else:
+                out.append({"size": [height, width], "counts": [0, height * width]})
+            continue
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if input_mask[i, 0] == 0 else [0]
+        counts += (
+            [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
+        )
+        out.append({"size": [height, width], "counts": counts})
+    return out
+
+
+def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+    """Compute a binary mask from an uncompressed RLE."""
+    height, width = rle["size"]
+    mask = np.empty(height * width, dtype=bool)
+    idx = 0
+    parity = False
+    for count in rle["counts"]:
+        mask[idx : idx + count] = parity
+        idx += count
+        parity = not parity
+    mask = mask.reshape(width, height)
+    return mask.transpose()  # Reshape to original shape
+
+
+def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+    """
+    Perform NMS (Non Maximum Suppression) on the outputs.
+
+    Args:
+        rle_masks (`torch.Tensor`):
+            binary masks in the RLE format
+        iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
+            iou_scores predicted by the model
+        mask_boxes (`torch.Tensor`):
+            The bounding boxes corresponding to segmentation masks
+        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+            NMS threshold.
+    """
+    keep_by_nms = batched_nms(
+        boxes=mask_boxes.float(),
+        scores=iou_scores,
+        idxs=torch.zeros(mask_boxes.shape[0]),
+        iou_threshold=amg_crops_nms_thresh,
+    )
+
+    iou_scores = iou_scores[keep_by_nms]
+    rle_masks = [rle_masks[i] for i in keep_by_nms]
+    mask_boxes = mask_boxes[keep_by_nms]
+    masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+    return masks, iou_scores, rle_masks, mask_boxes
+
+
+def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
+    """
+    Perform NMS (Non Maximum Suppression) on the outputs.
+
+    Args:
+        rle_masks (`tf.Tensor`):
+            binary masks in the RLE format
+        iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
+            iou_scores predicted by the model
+        mask_boxes (`tf.Tensor`):
+            The bounding boxes corresponding to segmentation masks
+        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
+            NMS threshold.
+    """
+    # All boxes belong to a single class here (the torch path passes idxs of zeros to `batched_nms`), so plain
+    # `tf.image.non_max_suppression` is equivalent.
+    keep_by_nms = tf.image.non_max_suppression(
+        boxes=tf.cast(mask_boxes, tf.float32),
+        scores=tf.reshape(iou_scores, [-1]),
+        max_output_size=mask_boxes.shape[0],
+        iou_threshold=amg_crops_nms_thresh,
+    )
+
+    iou_scores = tf.gather(iou_scores, keep_by_nms)
+    rle_masks = [rle_masks[i] for i in keep_by_nms.numpy()]
+    mask_boxes = tf.gather(mask_boxes, keep_by_nms)
+    masks = [_rle_to_mask(rle) for rle in rle_masks]
+
+    return masks, iou_scores, rle_masks, mask_boxes
+
+
+__all__ = ["SamImageProcessor"]
diff --git a/docs/transformers/src/transformers/models/sam/modeling_sam.py b/docs/transformers/src/transformers/models/sam/modeling_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..549eb8a317ed9c034952cf5140c2b14e089f41e8
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sam/modeling_sam.py
@@ -0,0 +1,1579 @@
+# coding=utf-8
+# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SAM model."""
+
+import collections
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    can_return_tuple,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "SamConfig"
+_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
+
+
+@dataclass
+class SamVisionEncoderOutput(ModelOutput):
+    """
+    Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the
+    projection layer to the pooler_output.
+
+    Args:
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+            The image embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    image_embeds: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SamImageSegmentationOutput(ModelOutput):
+    """
+    Base class for the Segment-Anything model's outputs.
+
+    Args:
+        iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
+            The iou scores of the predicted masks.
+        pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
+            The predicted low-resolution masks.
They need to be post-processed by the processor.
+        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
+        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    iou_scores: Optional[torch.FloatTensor] = None
+    pred_masks: Optional[torch.FloatTensor] = None
+    vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+class SamPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, pixel_values):
+        batch_size, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+        if height != self.image_size[0] or width != self.image_size[1]:
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class SamMLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam +class SamLayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class SamAttention(nn.Module): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
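+
+    For example (a minimal sketch, assuming access to this module and the default `SamMaskDecoderConfig` values
+    `hidden_size=256` and `num_attention_heads=8`), a `downsample_rate` of 2 projects queries, keys and values down to
+    an internal dimension of `256 // 2 = 128` before the attention product, and back up to `hidden_size` afterwards:
+
+    ```python
+    >>> from transformers import SamMaskDecoderConfig
+
+    >>> config = SamMaskDecoderConfig()
+    >>> attn = SamAttention(config, downsample_rate=2)
+    >>> attn.internal_dim
+    128
+    ```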
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states.transpose(1, 2) + + def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_heads, n_tokens, c_per_head = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None + ) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = query.shape + attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / (c_per_head**0.5) + attn = torch.softmax(attn, dim=-1) + + if attention_similarity is not None: + attn = attn + attention_similarity + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ value + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +class SamSdpaAttention(SamAttention): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. Using SDPA instead of the default attention. 
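+
+    Which class is used is selected through `config._attn_implementation` (see `SAM_ATTENTION_CLASSES` below). For
+    example (a sketch, assuming the checkpoint name), SDPA can be requested explicitly when loading the model:
+
+    ```python
+    >>> from transformers import SamModel
+
+    >>> model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")
+    ```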
+    """
+
+    def __init__(self, config, downsample_rate=None):
+        super().__init__(config, downsample_rate)
+
+    def forward(
+        self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
+    ) -> Tensor:
+        # Input projections
+        query = self.q_proj(query)
+        key = self.k_proj(key)
+        value = self.v_proj(value)
+
+        point_batch_size = query.shape[1]
+        # Separate into heads
+        query = self._separate_heads(query, self.num_attention_heads)
+        key = self._separate_heads(key, self.num_attention_heads)
+        value = self._separate_heads(value, self.num_attention_heads)
+
+        # Scaled dot product attention
+        attn_mask = None
+        if attention_similarity is not None:
+            attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
+
+        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
+
+        # Get output
+        out = self._recombine_heads(out, point_batch_size)
+        out = self.out_proj(out)
+
+        return out
+
+
+SAM_ATTENTION_CLASSES = {
+    "eager": SamAttention,
+    "sdpa": SamSdpaAttention,
+}
+
+
+class SamTwoWayAttentionBlock(nn.Module):
+    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
+        """
+        A transformer block with four layers:
+        (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) MLP block on
+        sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+        Args:
+            config (`SamMaskDecoderConfig`):
+                The configuration used to instantiate the block.
+            attention_downsample_rate (`int`, *optional*, defaults to 2):
+                The downsample ratio used to reduce the inner dimension of the cross-attention layers.
+            skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+                Whether or not to skip the addition of the query_point_embedding on the first layer.
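+
+        A minimal construction sketch (assuming access to this module and the default `SamMaskDecoderConfig`); note
+        that only the two cross-attention layers are downscaled, never the self-attention layer:
+
+        ```python
+        >>> from transformers import SamMaskDecoderConfig
+
+        >>> block = SamTwoWayAttentionBlock(SamMaskDecoderConfig(), skip_first_layer_pe=True)
+        >>> block.self_attn.internal_dim, block.cross_attn_token_to_image.internal_dim
+        (256, 128)
+        ```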
+ """ + super().__init__() + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation]( + config, downsample_rate=attention_downsample_rate + ) + self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.mlp = SamMLPBlock(config) + self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation]( + config, downsample_rate=attention_downsample_rate + ) + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs + + +class SamTwoWayTransformer(nn.Module): + def __init__(self, config: SamMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + 
image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+        # Prepare queries
+        queries = point_embeddings
+        keys = image_embeddings
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            if target_embedding is not None:
+                queries += target_embedding
+
+            queries, keys, attention_outputs = layer(
+                queries=queries,
+                keys=keys,
+                query_point_embedding=point_embeddings,
+                key_point_embedding=image_positional_embeddings,
+                attention_similarity=attention_similarity,
+                output_attentions=output_attentions,
+            )
+
+            if output_attentions:
+                all_attentions = all_attentions + (attention_outputs,)
+
+        # Apply the final attention layer from the points to the image
+        query = queries + point_embeddings
+        key = keys + image_positional_embeddings
+
+        attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+        queries = queries + attn_out
+        queries = self.layer_norm_final_attn(queries)
+        return queries, keys, all_attentions
+
+
+class SamFeedForward(nn.Module):
+    def __init__(
+        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.activation = nn.ReLU()
+        self.proj_in = nn.Linear(input_dim, hidden_dim)
+        self.proj_out = nn.Linear(hidden_dim, output_dim)
+        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        for layer in self.layers:
+            hidden_states = self.activation(layer(hidden_states))
+
+        hidden_states = self.proj_out(hidden_states)
+        if self.sigmoid_output:
+            hidden_states = F.sigmoid(hidden_states)
+        return hidden_states
+
+
+class SamMaskDecoder(nn.Module):
+    def __init__(self, config: SamMaskDecoderConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+
+        self.num_multimask_outputs = config.num_multimask_outputs
+        self.num_mask_tokens = config.num_multimask_outputs + 1
+
+        self.iou_token = nn.Embedding(1, self.hidden_size)
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
+
+        self.transformer = SamTwoWayTransformer(config)
+
+        # Upscaling head: two transposed convolutions that bring the image embedding
+        # from (hidden_size, H, W) up to (hidden_size // 8, 4H, 4W)
+        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
+        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
+        self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
+        self.activation = nn.GELU()
+
+        mlps_list = []
+        for _ in range(self.num_mask_tokens):
+            mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
+        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
+
+        self.iou_prediction_head = SamFeedForward(
+            self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
+        )
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_positional_embeddings: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        output_attentions: Optional[bool] = None,
+        attention_similarity: Optional[torch.Tensor] = None,
+        target_embedding: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
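+
+        With the default `num_multimask_outputs=3`, one IoU token and `num_multimask_outputs + 1 = 4` mask tokens are
+        prepended to the prompt tokens; `multimask_output=True` selects the three candidate masks (mask tokens 1 to
+        3), while `multimask_output=False` returns only the first mask.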
+ + Args: + image_embeddings (`torch.Tensor`): + the embeddings from the image encoder + image_positional_embedding (`torch.Tensor`): + positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes + dense_prompt_embeddings (`torch.Tensor`): + the embeddings of the mask inputs + multimask_output (bool): + Whether to return multiple masks or a single mask. + output_attentions (bool, *optional*): + Whether or not to return the attentions tensors of all attention layers. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.sum().item() != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-point + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + # Run the transformer, image_positional_embedding are consumed + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class SamPositionalEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.scale = config.hidden_size // 2 + self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats))) + + def forward(self, input_coords, 
input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class SamMaskEmbedding(nn.Module): + def __init__(self, config: SamPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = SamLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = SamLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class SamPromptEncoder(nn.Module): + def __init__(self, config: SamPromptEncoderConfig): + super().__init__() + self.shared_embedding = SamPositionalEmbedding(config.vision_config) + config = config.prompt_encoder_config + self.mask_embed = SamMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = nn.ModuleList( + [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)] + ) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = torch.zeros(target_point_shape, device=points.device) + padding_label = -torch.ones(target_labels_shape, device=labels.device) + points = torch.cat([points, padding_point], dim=2) + labels = torch.cat([labels, padding_label], dim=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. 
The dtype and device need to be explicitly
+        # specified, as otherwise torch.onnx.export interprets them as double
+        point_embedding = torch.where(
+            labels[..., None] != -10,
+            point_embedding,
+            torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
+        )
+
+        point_embedding = torch.where(
+            (labels == 0)[:, :, :, None],
+            point_embedding + self.point_embed[0].weight[None, None, :, :],
+            point_embedding,
+        )
+
+        point_embedding = torch.where(
+            (labels == 1)[:, :, :, None],
+            point_embedding + self.point_embed[1].weight[None, None, :, :],
+            point_embedding,
+        )
+
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel
+        batch_size, nb_boxes = boxes.shape[:2]
+        coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
+        input_shape = (self.input_image_size, self.input_image_size)
+        corner_embedding = self.shared_embedding(coords, input_shape)
+        corner_embedding[:, :, 0, :] += self.point_embed[2].weight
+        corner_embedding[:, :, 1, :] += self.point_embed[3].weight
+        return corner_embedding
+
+    def forward(
+        self,
+        input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        input_labels: Optional[torch.Tensor],
+        input_boxes: Optional[torch.Tensor],
+        input_masks: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense embeddings.
+
+        Args:
+            input_points (`torch.Tensor`, *optional*):
+                Point coordinates to embed.
+            input_labels (`torch.Tensor`, *optional*):
+                Labels of the points to embed.
+            input_boxes (`torch.Tensor`, *optional*):
+                Boxes to embed.
+            input_masks (`torch.Tensor`, *optional*):
+                Masks to embed.
+        """
+        sparse_embeddings = None
+        batch_size = 1
+        target_device = self.shared_embedding.positional_embedding.device
+        if input_points is not None:
+            batch_size, point_batch_size = input_points.shape[:2]
+            if input_labels is None:
+                raise ValueError("If points are provided, labels must also be provided.")
+            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+            sparse_embeddings = point_embeddings
+        if input_boxes is not None:
+            batch_size = input_boxes.shape[0]
+            box_embeddings = self._embed_boxes(input_boxes)
+            if sparse_embeddings is None:
+                sparse_embeddings = box_embeddings
+            else:
+                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+        if input_masks is not None:
+            dense_embeddings = self.mask_embed(input_masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        if sparse_embeddings is None:
+            sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
+
+        return sparse_embeddings, dense_embeddings
+
+
+class SamVisionAttention(nn.Module):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(self, config, window_size):
+        super().__init__()
+        input_size = (
+            (config.image_size // config.patch_size, config.image_size // config.patch_size)
+            if window_size == 0
+            else (window_size, window_size)
+        )
+
+        self.num_attention_heads = config.num_attention_heads
+        head_dim = config.hidden_size // config.num_attention_heads
+        self.scale = head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
+        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.use_rel_pos = config.use_rel_pos
+        if self.use_rel_pos:
+            if input_size
is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def get_decomposed_rel_pos( + self, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. 
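+
+        For example, with the default 14x14 attention windows (`q_size = k_size = (14, 14)`), `rel_pos_h` and
+        `rel_pos_w` each contribute `2 * 14 - 1 = 27` learned offsets, and the result has shape
+        `(batch_size * num_heads, 14, 14, 14, 14)`, which is then reshaped and added to the attention logits as a
+        bias.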
+ """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + + return decomposed_rel_pos + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class SamVisionSdpaAttention(SamVisionAttention): + """ + Multi-head Attention block with relative position embeddings. + Using SDPA instead of the default attention. + """ + + def __init__(self, config, window_size): + super().__init__(config, window_size) + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + if output_attentions: + logger.warning_once( + "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                output_attentions=output_attentions,
+            )
+
+        batch_size, height, width, _ = hidden_states.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = (
+            self.qkv(hidden_states)
+            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
+            .permute(2, 0, 3, 1, 4)
+        )
+        # q, k, v with shape (B * nHead, H * W, C)
+        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+        attn_bias = None
+        if self.use_rel_pos:
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            )
+            decomposed_rel_pos = decomposed_rel_pos.reshape(
+                batch_size, self.num_attention_heads, height * width, height * width
+            )
+            attn_bias = decomposed_rel_pos
+
+        query = query.view(batch_size, self.num_attention_heads, height * width, -1)
+        key = key.view(batch_size, self.num_attention_heads, height * width, -1)
+        value = value.view(batch_size, self.num_attention_heads, height * width, -1)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
+
+        attn_output = (
+            attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
+            .permute(0, 2, 3, 1, 4)
+            .reshape(batch_size, height, width, -1)
+        )
+
+        attn_output = self.proj(attn_output)
+
+        return attn_output, None
+
+
+SAM_VISION_ATTENTION_CLASSES = {
+    "eager": SamVisionAttention,
+    "sdpa": SamVisionSdpaAttention,
+}
+
+
+class SamVisionLayer(nn.Module):
+    def __init__(self, config, window_size):
+        super().__init__()
+        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
+        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = SamMLPBlock(config)
+        self.window_size = window_size
+
+    def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+        """
+        Partition into non-overlapping windows with padding if needed.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                Input tokens with [batch_size, height, width, channel].
+            window_size (`int`):
+                Window size.
+
+        Returns:
+            windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
+            (pad_height, pad_width): padded height and width before partition
+        """
+        batch_size, height, width, channel = hidden_states.shape
+
+        pad_h = (window_size - height % window_size) % window_size
+        pad_w = (window_size - width % window_size) % window_size
+        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
+        pad_height, pad_width = height + pad_h, width + pad_w
+
+        hidden_states = hidden_states.reshape(
+            batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
+        )
+        windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
+        return windows, (pad_height, pad_width)
+
+    def window_unpartition(
+        self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
+    ) -> torch.Tensor:
+        """
+        Window unpartition back into the original sequence, removing padding.
+
+        Args:
+            windows (`torch.Tensor`):
+                Input tokens with [batch_size * num_windows, window_size, window_size, channel].
+            window_size (`int`):
+                Window size.
+            padding_shape (`Tuple[int, int]`):
+                Padded height and width (pad_height, pad_width).
+ original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SamVisionNeck(nn.Module): + def __init__(self, config: SamVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +class SamVisionEncoder(nn.Module): + def __init__(self, config: SamVisionConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + + self.patch_embed = SamPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
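+            # With the default `SamVisionConfig` (image_size=1024, patch_size=16), this is a learned
+            # (1, 64, 64, hidden_size) table that is added directly to the patch embeddings.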
+ self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = SamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = SamVisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + @can_return_tuple + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SamVisionEncoderOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + return SamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SamPreTrainedModel(PreTrainedModel): + config_class = SamConfig + base_model_prefix = "sam" + main_input_name = "pixel_values" + _no_split_modules = ["SamVisionAttention"] + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (SamLayerNorm, nn.LayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, SamVisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + + +SAM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
+
+    Parameters:
+        config ([`SamConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SAM_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
+            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
+            better results. The points can be obtained by passing a list of list of list to the processor that will
+            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+            per input point), the third dimension is the number of points per segmentation mask (it is possible to
+            pass multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+            coordinates of the point. If a different number of points is passed either for each image, or for each
+            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+            computation of the embedding will be skipped for these points using the labels.
+        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+            official implementation, there are 3 types of labels
+
+            - `1`: the point is a point that contains the object of interest
+            - `0`: the point is a point that does not contain the object of interest
+            - `-1`: the point corresponds to the background
+
+            We added the label:
+
+            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+            The padding labels should be automatically done by the processor.
+        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields much
+            better generated masks. The boxes can be obtained by passing a list of list of list to the processor, that
+            will generate a `torch` tensor, with each dimension corresponding respectively to the image batch size,
+            the number of boxes per image and the coordinates of the top left and bottom right point of the box. In
+            the order (`x1`, `y1`, `x2`, `y2`):
+
+            - `x1`: the x coordinate of the top left point of the input box
+            - `y1`: the y coordinate of the top left point of the input box
+            - `x2`: the x coordinate of the bottom right point of the input box
+            - `y2`: the y coordinate of the bottom right point of the input box
+
+        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
+            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+
+        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+        multimask_output (`bool`, *optional*):
+            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+            "best" mask, by specifying `multimask_output=False`.
+        attention_similarity (`torch.FloatTensor`, *optional*):
+            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+            model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+        target_embedding (`torch.FloatTensor`, *optional*):
+            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+            the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
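+
+    For example (a minimal sketch, assuming the checkpoint name), the vision encoder can be run on its own to obtain
+    image embeddings:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import SamProcessor, SamVisionModel
+
+    >>> model = SamVisionModel.from_pretrained("facebook/sam-vit-base")
+    >>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+    >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+    >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+    >>> inputs = processor(images=raw_image, return_tensors="pt")
+
+    >>> outputs = model(pixel_values=inputs["pixel_values"])
+    >>> image_embeddings = outputs.last_hidden_state
+    ```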
+""" + + +@add_start_docstrings( + """The vision model from Sam without any head or projection on top.""", + SAM_START_DOCSTRING, +) +class SamVisionModel(SamPreTrainedModel): + config_class = SamVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SamVisionConfig): + super().__init__(config) + self.vision_encoder = SamVisionEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_encoder.patch_embed + + @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SamVisionEncoderOutput]: + r""" + Returns: + + """ + return self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings( + "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + SAM_START_DOCSTRING, +) +class SamModel(SamPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config): + super().__init__(config) + self.shared_image_embedding = SamPositionalEmbedding(config.vision_config) + + self.vision_encoder = SamVisionEncoder(config.vision_config) + self.prompt_encoder = SamPromptEncoder(config) + self.mask_decoder = SamMaskDecoder(config.mask_decoder_config) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones((size, size), device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
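+
+        For example (a sketch, assuming `inputs` was produced by [`SamProcessor`]), the image embeddings can be
+        computed once and then reused for several different prompts:
+
+        ```python
+        >>> image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+        >>> inputs.pop("pixel_values", None)
+        >>> outputs = model(**inputs, image_embeddings=image_embeddings)
+        ```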
+ """ + vision_output = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + image_embeddings = vision_output[0] + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @can_return_tuple + @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> SamImageSegmentationOutput: + r""" + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("facebook/sam-vit-base") + >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get segmentation mask + >>> outputs = model(**inputs) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... 
)
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if pixel_values is None and image_embeddings is None:
+            raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+        if pixel_values is not None and image_embeddings is not None:
+            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+        if input_points is not None and len(input_points.shape) != 4:
+            raise ValueError(
+                "The input_points must be a 4D tensor of shape (batch_size, point_batch_size, nb_points_per_image,"
+                f" 2), got {input_points.shape}."
+            )
+        if input_boxes is not None and len(input_boxes.shape) != 3:
+            raise ValueError(
+                f"The input_boxes must be a 3D tensor of shape (batch_size, nb_boxes, 4), got {input_boxes.shape}."
+            )
+        if input_points is not None and input_boxes is not None:
+            point_batch_size = input_points.shape[1]
+            box_batch_size = input_boxes.shape[1]
+            if point_batch_size != box_batch_size:
+                raise ValueError(
+                    "You should provide as many bounding boxes as input point batches. Got {} and {}.".format(
+                        point_batch_size, box_batch_size
+                    )
+                )
+
+        image_positional_embeddings = self.get_image_wide_positional_embeddings()
+        # repeat with batch size
+        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
+        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+        vision_attentions = None
+        vision_hidden_states = None
+
+        if pixel_values is not None:
+            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(
+                pixel_values,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            image_embeddings = vision_outputs.last_hidden_state
+
+            if output_hidden_states:
+                vision_hidden_states = vision_outputs.hidden_states
+            if output_attentions:
+                vision_attentions = vision_outputs.attentions
+
+        if input_points is not None and input_labels is None:
+            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
+            raise ValueError(
+                "The batch size of the image embeddings and the input points must be the same. Got {} and {}"
+                " respectively. If you want to pass multiple points for the same image, make sure that you passed"
+                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and input_labels of"
+                " shape (batch_size, point_batch_size, num_points_per_image).".format(
+                    image_embeddings.shape[0], input_points.shape[0]
+                )
+            )
+
+        sparse_embeddings, dense_embeddings = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+
+        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+
+        return SamImageSegmentationOutput(
+            iou_scores=iou_predictions,
+            pred_masks=low_res_masks,
+            vision_hidden_states=vision_hidden_states,
+            vision_attentions=vision_attentions,
+            mask_decoder_attentions=mask_decoder_attentions,
+        )
+
+
+__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/sam/modeling_tf_sam.py b/docs/transformers/src/transformers/models/sam/modeling_tf_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe06e05424a3e59accda6d14b91af447321499d
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sam/modeling_tf_sam.py
@@ -0,0 +1,1726 @@
+# coding=utf-8
+# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
+discrepancy, the original file should be regarded as the 'reference' version.
+"""
+
+from __future__ import annotations
+
+import collections
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import ACT2FN
+from ...modeling_tf_outputs import TFBaseModelOutput
+from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
+from ...tf_utils import flatten, functional_layernorm
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "SamConfig"
+_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
+
+
+@dataclass
+class TFSamVisionEncoderOutput(ModelOutput):
+    """
+    Base class for SAM vision model's outputs that also contains image embeddings obtained by applying the projection
+    layer to the pooler_output.
+
+    Args:
+        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
+            The image embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, plus
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    image_embeds: tf.Tensor | None = None
+    last_hidden_state: tf.Tensor | None = None
+    hidden_states: Tuple[tf.Tensor, ...] | None = None
+    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFSamImageSegmentationOutput(ModelOutput):
+    """
+    Base class for the Segment-Anything model's outputs.
+
+    Args:
+        iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
+            The IoU scores of the predicted masks.
+        pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
+            The predicted low-resolution masks. These need to be post-processed by the processor.
+        vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, plus
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
+        vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    iou_scores: tf.Tensor | None = None
+    pred_masks: tf.Tensor | None = None
+    vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
+    vision_attentions: Tuple[tf.Tensor, ...] | None = None
+    mask_decoder_attentions: Tuple[tf.Tensor, ...]
| None = None + + +class TFSamPatchEmbeddings(keras.layers.Layer): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = keras.layers.Conv2D( + hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" + ) + + def call(self, pixel_values): + batch_size, num_channels, height, width = shape_list(pixel_values) + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1])) + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +class TFSamMLPBlock(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1") + self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2") + self.act = ACT2FN[config.hidden_act] + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "lin1", None) is not None: + with tf.name_scope(self.lin1.name): + self.lin1.build([None, None, self.config.hidden_size]) + if getattr(self, "lin2", None) is not None: + with tf.name_scope(self.lin2.name): + self.lin2.build([None, None, self.config.mlp_dim]) + + +class TFSamLayerNorm(keras.layers.Layer): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
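+
+    With `data_format="channels_first"` the normalization is applied over axis 1, and with `"channels_last"` over
+    the final axis. For example (an illustrative sketch), `TFSamLayerNorm(normalized_shape=(256,),
+    data_format="channels_first")` normalizes a `(batch_size, 256, height, width)` feature map at each spatial
+    location.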
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.data_format = data_format + self.normalized_shape = normalized_shape + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + + def build(self, input_shape): + self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight") + self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias") + super().build(input_shape) + + def call(self, x: tf.Tensor) -> tf.Tensor: + if self.data_format == "channels_last": + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) + elif self.data_format == "channels_first": + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) + return x + + +class TFSamAttention(keras.layers.Layer): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj") + self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj") + self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj") + self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj") + + def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: + batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) + c_per_head = channel // num_attention_heads + hidden_states = tf.reshape( + hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + ) + return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + + def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: + batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + return tf.reshape( + hidden_states, + (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head), + ) + + def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = shape_list(query)[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = shape_list(query) + attn = tf.matmul( + query, tf.transpose(key, perm=[0, 1, 3, 2]) + ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / tf.math.sqrt(float(c_per_head)) + attn = tf.nn.softmax(attn, axis=-1) + + # Get output + out = tf.matmul(attn, value) + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + def build(self, 
input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.hidden_size]) + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.hidden_size]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.internal_dim]) + + +class TFSamTwoWayAttentionBlock(keras.layers.Layer): + def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`SamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. + """ + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn") + self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1") + + self.cross_attn_token_to_image = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image" + ) + self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2") + + self.mlp = TFSamMLPBlock(config, name="mlp") + self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3") + + self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4") + self.cross_attn_image_to_token = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token" + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def call( + self, + queries: tf.Tensor, + keys: tf.Tensor, + query_point_embedding: tf.Tensor, + key_point_embedding: tf.Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys 
+ attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, None, self.hidden_size]) + if getattr(self, "cross_attn_token_to_image", None) is not None: + with tf.name_scope(self.cross_attn_token_to_image.name): + self.cross_attn_token_to_image.build(None) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, None, self.hidden_size]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "layer_norm3", None) is not None: + with tf.name_scope(self.layer_norm3.name): + self.layer_norm3.build([None, None, None, self.hidden_size]) + if getattr(self, "layer_norm4", None) is not None: + with tf.name_scope(self.layer_norm4.name): + self.layer_norm4.build([None, None, None, self.hidden_size]) + if getattr(self, "cross_attn_image_to_token", None) is not None: + with tf.name_scope(self.cross_attn_image_to_token.name): + self.cross_attn_image_to_token.build(None) + + +class TFSamTwoWayTransformer(keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = [] + + for i in range(self.num_hidden_layers): + self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}")) + + self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") + self.layer_norm_final_attn = keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layer_norm_final_attn" + ) + + def call( + self, + point_embeddings: tf.Tensor, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TFBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None] + image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None] + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attenion layer from the points to the 
image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "final_attn_token_to_image", None) is not None: + with tf.name_scope(self.final_attn_token_to_image.name): + self.final_attn_token_to_image.build(None) + if getattr(self, "layer_norm_final_attn", None) is not None: + with tf.name_scope(self.layer_norm_final_attn.name): + self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size]) + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSamFeedForward(keras.layers.Layer): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs + ): + super().__init__(**kwargs) + self.num_layers = num_layers + self.activation = keras.layers.ReLU() + self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") + self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") + self.layers = [ + keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}") + for i in range(num_layers - 2) + ] + self.sigmoid_output = sigmoid_output + self.hidden_dim = hidden_dim + self.input_dim = input_dim + + def call(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = tf.sigmoid(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "proj_in", None) is not None: + with tf.name_scope(self.proj_in.name): + self.proj_in.build([None, None, self.input_dim]) + if getattr(self, "proj_out", None) is not None: + with tf.name_scope(self.proj_out.name): + self.proj_out.build([None, None, self.hidden_dim]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build([None, None, self.hidden_dim]) + + +class TFSamMaskDecoder(keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.transformer = TFSamTwoWayTransformer(config, name="transformer") + + self.upscale_conv1 = keras.layers.Conv2DTranspose( + self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first" + ) + self.upscale_conv2 = keras.layers.Conv2DTranspose( + self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first" + ) + self.upscale_layer_norm = TFSamLayerNorm( + self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" + ) + self.activation = tf.nn.gelu + + mlps_list = [] + for i in range(self.num_mask_tokens): + mlps_list += [ + TFSamFeedForward( + self.hidden_size, + self.hidden_size, + self.hidden_size // 8, + 3, + name=f"output_hypernetworks_mlps_._{i}", + ) + ] + self.output_hypernetworks_mlps = mlps_list + + self.iou_prediction_head 
= TFSamFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + name="iou_prediction_head", + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) + self.mask_tokens = self.add_weight( + shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True + ) + + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + if getattr(self, "upscale_conv1", None) is not None: + with tf.name_scope(self.upscale_conv1.name): + self.upscale_conv1.build([None, self.hidden_size, None, None]) + if getattr(self, "upscale_conv2", None) is not None: + with tf.name_scope(self.upscale_conv2.name): + self.upscale_conv2.build([None, self.hidden_size // 4, None, None]) + if getattr(self, "upscale_layer_norm", None) is not None: + with tf.name_scope(self.upscale_layer_norm.name): + self.upscale_layer_norm.build(None) + if getattr(self, "iou_prediction_head", None) is not None: + with tf.name_scope(self.iou_prediction_head.name): + self.iou_prediction_head.build(None) + for mlp in self.output_hypernetworks_mlps: + with tf.name_scope(mlp.name): + mlp.build(None) + + def call( + self, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + sparse_prompt_embeddings: tf.Tensor, + dense_prompt_embeddings: tf.Tensor, + multimask_output: bool, + output_attentions: Optional[bool] = None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + batch_size, num_channels, height, width = shape_list(image_embeddings) + point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) + + output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32) + output_tokens = tf.tile( + output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] + ) # Should be (batch_size, point_size, 5, 32) + + # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only + # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced + # it with an explicit shape check to avoid data-dependent control flow which breaks XLA. 
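+        # NOTE (illustrative, not from the original file): a data-dependent branch such as
+        #     tokens = tf.cond(tf.reduce_sum(sparse_prompt_embeddings) != 0, ...)
+        # cannot be lowered by XLA, whereas the `shape_list(sparse_prompt_embeddings)[1] != 0` check below
+        # is a static, trace-time Python comparison and therefore compiles cleanly.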
+ if shape_list(sparse_prompt_embeddings)[1] != 0: + tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) + else: + tokens = output_tokens + point_embeddings = tf.cast(tokens, self.iou_token.dtype) + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0) + image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) + image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = tf.stack(hyper_in_list, axis=2) + + _, num_channels, height, width = shape_list(upscaled_embedding) + upscaled_embedding = tf.reshape( + upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width] + ) + masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) + + iou_pred = self.iou_prediction_head(iou_token_out) + + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class TFSamPositionalEmbedding(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.scale = config.hidden_size // 2 + self.config = config + + def build(self, input_shape): + # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized? + self.positional_embedding = self.add_weight( + name="positional_embedding", + shape=(2, self.config.num_pos_feats), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), + trainable=False, + ) + super().build(input_shape) + + def call(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = tf.identity(input_coords) + + if input_shape is not None: + coordinates = tf.stack( + [ + tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], + tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0], + ], + axis=-1, + ) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = tf.cast(coordinates, self.positional_embedding.dtype) + coordinates = tf.matmul(coordinates, self.positional_embedding) + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... 
x d_n x channel shape + return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) + + +class TFSamMaskEmbedding(keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, **kwargs): + super().__init__(**kwargs) + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") + self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2") + self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") + self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") + self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") + self.config = config + + def call(self, masks): + masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first + return dense_embeddings + + def build(self, input_shape=None): + # This class needs an explicit build method because it isn't called with the standard dummy inputs + if self.built: + return + self.built = True + with tf.name_scope("conv1"): + self.conv1.build([None, None, None, 1]) + with tf.name_scope("conv2"): + self.conv2.build([None, None, None, self.mask_input_channels]) + with tf.name_scope("conv3"): + self.conv3.build([None, None, None, self.mask_input_channels * 4]) + with tf.name_scope("layer_norm1"): + self.layer_norm1.build([None, None, None, self.mask_input_channels]) + with tf.name_scope("layer_norm2"): + self.layer_norm2.build([None, None, None, self.mask_input_channels * 4]) + + +class TFSamPromptEncoder(keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): + super().__init__(**kwargs) + self.shared_embedding = shared_patch_embedding + self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed") + self.no_mask_embed = None + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = [] + self.hidden_size = config.hidden_size + self.not_a_point_embed = None + self.config = config + + def build(self, input_shape=None): + self.no_mask_embed = self.add_weight( + name="no_mask_embed.weight", + shape=(1, self.hidden_size), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + self.point_embed = [ + self.add_weight( + name=f"point_embed_._{i}.weight", + shape=(1, self.hidden_size), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + for i in range(self.config.num_point_embeddings) + ] + self.not_a_point_embed = self.add_weight( + name="not_a_point_embed.weight", + shape=(1, self.hidden_size), + initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + with tf.name_scope("mask_embed"): + # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs + self.mask_embed.build( + (None, 
self.config.mask_input_channels, self.config.image_size, self.config.image_size) + ) + + if self.built: + return + self.built = True + if getattr(self, "mask_embed", None) is not None: + with tf.name_scope(self.mask_embed.name): + self.mask_embed.build(None) + + def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1]) + target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1) + padding_point = tf.zeros(target_point_shape, dtype=points.dtype) + padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) + points = tf.concat([points, padding_point], axis=2) + labels = tf.concat([labels, padding_label], axis=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding) + + point_embedding = tf.where( + labels[..., None] != -10, + point_embedding, + tf.zeros_like(point_embedding), + ) + point_embedding = tf.where( + (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding + ) + point_embedding = tf.where( + (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding + ) + return point_embedding + + def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = shape_list(boxes)[:2] + coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding += tf.where( + tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0, + self.point_embed[2][0], + self.point_embed[3][0], + ) + return corner_embedding + + def call( + self, + batch_size: Optional[int], + input_points: Optional[Tuple[tf.Tensor, tf.Tensor]], + input_labels: tf.Tensor | None, + input_boxes: tf.Tensor | None, + input_masks: tf.Tensor | None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`tf.Tensor`, *optional*): + point coordinates and labels to embed. 
+ boxes (`tf.Tensor`, *optional*): + boxes to embed + masks (`tf.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + if input_points is not None: + batch_size, point_batch_size = shape_list(input_points)[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = tf.zeros( + (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype + ) + sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) + if input_boxes is not None: + batch_size = shape_list(input_boxes)[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed[0] + dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1)) + dense_embeddings = tf.tile( + dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) + ) + if sparse_embeddings is None: + sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) + + return sparse_embeddings, dense_embeddings + + +class TFSamVisionAttention(keras.layers.Layer): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size, **kwargs): + super().__init__(**kwargs) + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + self.input_size = input_size + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = head_dim + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv") + self.proj = keras.layers.Dense(config.hidden_size, name="proj") + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + self.config = config + + def build(self, input_shape=None): + if self.input_size is not None: + # initialize relative positional embeddings + self.rel_pos_h = self.add_weight( + shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" + ) + self.rel_pos_w = self.add_weight( + shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w" + ) + + if self.built: + return + self.built = True + if getattr(self, "qkv", None) is not None: + with tf.name_scope(self.qkv.name): + self.qkv.build([None, None, self.config.hidden_size]) + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build([None, None, self.config.hidden_size]) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`tf.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. 
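+
+        Example (illustrative): with `q_size=2` and `k_size=4`, the query coordinates are scaled by
+        `max(k_size / q_size, 1.0) = 2.0`, the key coordinates by `max(q_size / k_size, 1.0) = 1.0`, and the
+        offset is `(k_size - 1) * 1.0 = 3.0`, so `relative_coords` is `[[3, 2, 1, 0], [5, 4, 3, 2]]`, indexing
+        into the `2 * max(q_size, k_size) - 1 = 7` stored embeddings.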
+ """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = tf.image.resize( + tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)), + size=(max_rel_dist, rel_pos.shape[1]), + method="bilinear", + ) + rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist)) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0) + k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32)) + + def get_decomposed_rel_pos( + self, + query: tf.Tensor, + rel_pos_h: tf.Tensor, + rel_pos_w: tf.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> tf.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`tf.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`tf.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`tf.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. 
+ """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = shape_list(query) + reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) + rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + rel_h = tf.expand_dims(rel_h, axis=-1) + rel_w = tf.expand_dims(rel_w, axis=-2) + decomposed_rel_pos = rel_h + rel_w + + return decomposed_rel_pos + + def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: + batch_size, height, width, _ = shape_list(hidden_states) + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) + qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = tf.unstack( + tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 + ) + attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights)) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + if training: + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + else: + attn_probs = attn_weights + + attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) + attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) + attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size)) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class TFSamVisionLayer(keras.layers.Layer): + def __init__(self, config, window_size, **kwargs): + super().__init__(**kwargs) + self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.attn = TFSamVisionAttention(config, window_size, name="attn") + self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + self.mlp = TFSamMLPBlock(config, name="mlp") + self.window_size = window_size + self.config = config + + def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]: + batch_size, height, width, channel = shape_list(hidden_states) + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + if pad_h > 0 or pad_w > 0: + hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = tf.reshape( + hidden_states, + [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel], + ) + windows = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel] + ) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, 
windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> tf.Tensor: + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = tf.reshape( + windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] + ) + hidden_states = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1] + ) + + if pad_height > height or pad_width > width: + hidden_states = hidden_states[:, :height, :width, :] + return hidden_states + + def call( + self, + hidden_states: tf.Tensor, + output_attentions: Optional[bool] = False, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + training=training, + ) + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build([None, None, None, self.config.hidden_size]) + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build([None, None, None, self.config.hidden_size]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + + +class TFSamVisionNeck(keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.conv1 = keras.layers.Conv2D( + config.output_channels, + kernel_size=1, + use_bias=False, + name="conv1", + ) + self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") + self.conv2 = keras.layers.Conv2D( + config.output_channels, + kernel_size=3, + padding="same", + use_bias=False, + name="conv2", + ) + self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2") + + def call(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build([None, None, None, self.config.hidden_size]) + if getattr(self, "layer_norm1", None) is not None: + with tf.name_scope(self.layer_norm1.name): + self.layer_norm1.build(None) + if getattr(self, "conv2", None) 
is not None: + with tf.name_scope(self.conv2.name): + self.conv2.build([None, None, None, self.config.output_channels]) + if getattr(self, "layer_norm2", None) is not None: + with tf.name_scope(self.layer_norm2.name): + self.layer_norm2.build(None) + + +class TFSamVisionEncoder(keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.image_size = config.image_size + + self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed") + + self.pos_embed = None + + self.layers = [] + for i in range(config.num_hidden_layers): + layer = TFSamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + name=f"layers_._{i}", + ) + self.layers.append(layer) + + self.neck = TFSamVisionNeck(config, name="neck") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = self.add_weight( + shape=[ + 1, + self.config.image_size // self.config.patch_size, + self.config.image_size // self.config.patch_size, + self.config.hidden_size, + ], + initializer="zeros", + trainable=True, + name="pos_embed", + ) + + if getattr(self, "patch_embed", None) is not None: + with tf.name_scope(self.patch_embed.name): + self.patch_embed.build(None) + if getattr(self, "neck", None) is not None: + with tf.name_scope(self.neck.name): + self.neck.build(None) + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + def get_input_embeddings(self): + return self.patch_embed + + def call( + self, + pixel_values: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFSamVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return TFSamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class TFSamPreTrainedModel(TFPreTrainedModel): + config_class = SamConfig + base_model_prefix = 
"sam" + main_input_name = "pixel_values" + + +SAM_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) + subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`SamConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SAM_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second + dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per + input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size, + the number of boxes per image and the coordinates of the top left and botton right point of the box. 
In the + order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `call` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +SAM_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
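+
+    Example (an illustrative sketch; the `facebook/sam-vit-base` checkpoint and the image URL are assumptions,
+    not part of the reference implementation):
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import SamProcessor, TFSamVisionModel
+
+    >>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+    >>> model = TFSamVisionModel.from_pretrained("facebook/sam-vit-base")
+
+    >>> img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+    >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+
+    >>> inputs = processor(images=raw_image, return_tensors="tf")
+    >>> outputs = model(**inputs)
+    >>> image_embeddings = outputs.last_hidden_state  # channels-first output of the neck
+    ```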
+""" + + +@add_start_docstrings( + """The vision model from Sam without any head or projection on top.""", + SAM_START_DOCSTRING, +) +class TFSamVisionModel(TFSamPreTrainedModel): + config_class = SamVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(config, **kwargs) + self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vision_encoder", None) is not None: + with tf.name_scope(self.vision_encoder.name): + self.vision_encoder.build(None) + + def get_input_embeddings(self): + return self.vision_encoder.patch_embed + + @unpack_inputs + @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig) + def call( + self, + pixel_values: TFModelInputType | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + **kwargs, + ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]: + r""" + Returns: + + """ + return self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + +@add_start_docstrings( + "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + SAM_START_DOCSTRING, +) +class TFSamModel(TFSamPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") + + self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") + self.prompt_encoder = TFSamPromptEncoder( + config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" + ) + self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") + self.config = config + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + grid = tf.ones((size, size)) + y_embed = tf.math.cumsum(grid, axis=0) - 0.5 + x_embed = tf.math.cumsum(grid, axis=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1)) + return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width + + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ """
+ vision_output = self.vision_encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ image_embeddings = vision_output[0]
+ return image_embeddings
+
+ def get_prompt_embeddings(
+ self,
+ input_points: tf.Tensor | None = None,
+ input_labels: tf.Tensor | None = None,
+ input_boxes: tf.Tensor | None = None,
+ input_masks: tf.Tensor | None = None,
+ ):
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the points is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ input_points: tf.Tensor | None = None,
+ input_labels: tf.Tensor | None = None,
+ input_boxes: tf.Tensor | None = None,
+ input_masks: tf.Tensor | None = None,
+ image_embeddings: tf.Tensor | None = None,
+ multimask_output: bool = True,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ **kwargs,
+ ) -> TFSamImageSegmentationOutput | Tuple[tf.Tensor]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None and image_embeddings is None:
+ raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+ if pixel_values is not None and image_embeddings is not None:
+ raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+ if input_points is not None and len(input_points.shape) != 4:
+ raise ValueError(
+ "The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
+ " Got {}.".format(input_points.shape),
+ )
+ if input_boxes is not None and len(input_boxes.shape) != 3:
+ raise ValueError(
+ "The input_boxes must be a 3D tensor of shape `batch_size`, `nb_boxes`, `4`.",
+ " Got {}.".format(input_boxes.shape),
+ )
+ if input_points is not None and input_boxes is not None:
+ point_batch_size = shape_list(input_points)[1]
+ box_batch_size = shape_list(input_boxes)[1]
+ if point_batch_size != box_batch_size:
+ raise ValueError(
+ "The number of box batches must match the number of point batches. Got {} point batches and {} box batches.".format(
+ point_batch_size, box_batch_size
+ )
+ )
+ if pixel_values is not None:
+ # Ensures that later checks pass even with an all-None shape from the serving signature
+ pixel_values = tf.ensure_shape(
+ pixel_values,
+ [
+ None,
+ self.config.vision_config.num_channels,
+ self.config.vision_config.image_size,
+ self.config.vision_config.image_size,
+ ],
+ )
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
+ image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ vision_outputs = self.vision_encoder(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ training=training,
+ )
+ image_embeddings = vision_outputs["last_hidden_state"]
+
+ if output_hidden_states:
+ vision_hidden_states = vision_outputs["hidden_states"]
+ if output_attentions:
+ vision_attentions = vision_outputs["attentions"]
+
+ if input_points is not None and input_labels is None:
+ input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)
+
+ if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
+ raise ValueError(
+ "The batch size of the image embeddings and the input points must be the same.",
", + "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + batch_size=shape_list(image_embeddings)[0], + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return TFSamImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) + + def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput: + hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None + + return TFSamImageSegmentationOutput( + iou_scores=output.iou_scores, + pred_masks=output.pred_masks, + vision_hidden_states=hs if self.config.output_hidden_states else None, + vision_attentions=attns if self.config.output_attentions else None, + mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "shared_image_embedding", None) is not None: + with tf.name_scope(self.shared_image_embedding.name): + self.shared_image_embedding.build(None) + if getattr(self, "vision_encoder", None) is not None: + with tf.name_scope(self.vision_encoder.name): + self.vision_encoder.build(None) + if getattr(self, "prompt_encoder", None) is not None: + with tf.name_scope(self.prompt_encoder.name): + self.prompt_encoder.build(None) + if getattr(self, "mask_decoder", None) is not None: + with tf.name_scope(self.mask_decoder.name): + self.mask_decoder.build(None) + + +__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/sam/processing_sam.py b/docs/transformers/src/transformers/models/sam/processing_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..0b30261fa9130863c710de5a0050803fb081edd0 --- /dev/null +++ b/docs/transformers/src/transformers/models/sam/processing_sam.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SAM.
+"""
+
+from copy import deepcopy
+from typing import List, Optional, Union
+
+import numpy as np
+
+from ...image_utils import ImageInput, VideoInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
+from ...utils import is_tf_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_tf_available():
+ import tensorflow as tf
+
+
+class SamImagesKwargs(ImagesKwargs):
+ segmentation_maps: Optional[ImageInput]
+ input_points: Optional[List[List[float]]]
+ input_labels: Optional[List[List[int]]]
+ input_boxes: Optional[List[List[List[float]]]]
+ point_pad_value: Optional[int]
+
+
+class SamProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: SamImagesKwargs
+ _defaults = {
+ "images_kwargs": {
+ "point_pad_value": -10,
+ }
+ }
+
+
+class SamProcessor(ProcessorMixin):
+ r"""
+ Constructs a SAM processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
+ single processor.
+
+ [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
+ [`~SamImageProcessor.__call__`] for more information.
+
+ Args:
+ image_processor (`SamImageProcessor`):
+ An instance of [`SamImageProcessor`]. The image processor is a required input.
+ """
+
+ attributes = ["image_processor"]
+ image_processor_class = "SamImageProcessor"
+ # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
+ optional_call_args = [
+ "segmentation_maps",
+ "input_points",
+ "input_labels",
+ "input_boxes",
+ ]
+
+ def __init__(self, image_processor):
+ super().__init__(image_processor)
+ self.target_size = self.image_processor.size["longest_edge"]
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
+ # arguments that may be passed as positional arguments.
+ # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
+ # or this conversation for more context:
+ # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
+ # This behavior is only needed for backward compatibility and will be removed in future versions.
+ *args, # to be deprecated
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ audio: Optional[AudioInput] = None,
+ video: Optional[VideoInput] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ This method uses the [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares
+ 2D points and bounding boxes for the model if they are provided.
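For concreteness, here is a hedged sketch of a box-prompted call, reusing the `processor` and `image` from the earlier sketch (the coordinates are made up). Note the nesting expected by `_check_and_preprocess_points` further down: a list over images, then a list of boxes per image, each box a list of four floats `[x1, y1, x2, y2]` in original-image coordinates.

```python
# Hedged sketch: box prompts nest as [images][boxes][x1, y1, x2, y2].
# The processor rescales the coordinates to the model's input frame.
inputs = processor(
    images=image,
    input_boxes=[[[75.0, 275.0, 1725.0, 850.0]]],
    return_tensors="tf",
)
```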
+ """ + output_kwargs = self._merge_kwargs( + SamProcessorKwargs, + tokenizer_init_kwargs={}, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + input_points = output_kwargs["images_kwargs"].pop("input_points", None) + input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) + input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None) + point_pad_value = output_kwargs["images_kwargs"].pop("point_pad_value", None) + + encoding_image_processor = self.image_processor( + images, + **output_kwargs["images_kwargs"], + ) + + # pop arguments that are not used in the foward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + + if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor + original_sizes = original_sizes.numpy() + + input_points, input_labels, input_boxes = self._check_and_preprocess_points( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + ) + + encoding_image_processor = self._normalize_and_convert( + encoding_image_processor, + original_sizes, + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + return_tensors=output_kwargs["common_kwargs"].get("return_tensors"), + point_pad_value=point_pad_value, + ) + + return encoding_image_processor + + def _normalize_and_convert( + self, + encoding_image_processor, + original_sizes, + input_points=None, + input_labels=None, + input_boxes=None, + return_tensors="pt", + point_pad_value=-10, + ): + if input_points is not None: + if len(original_sizes) != len(input_points): + input_points = [ + self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points + ] + else: + input_points = [ + self._normalize_coordinates(self.target_size, point, original_size) + for point, original_size in zip(input_points, original_sizes) + ] + # check that all arrays have the same shape + if not all(point.shape == input_points[0].shape for point in input_points): + if input_labels is not None: + input_points, input_labels = self._pad_points_and_labels( + input_points, input_labels, point_pad_value + ) + + input_points = np.array(input_points) + + if input_labels is not None: + input_labels = np.array(input_labels) + + if input_boxes is not None: + if len(original_sizes) != len(input_boxes): + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True) + for box in input_boxes + ] + else: + input_boxes = [ + self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True) + for box, original_size in zip(input_boxes, original_sizes) + ] + input_boxes = np.array(input_boxes) + + if input_boxes is not None: + if return_tensors == "pt": + input_boxes = torch.from_numpy(input_boxes) + # boxes batch size of 1 by default + input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes + elif return_tensors == "tf": + input_boxes = tf.convert_to_tensor(input_boxes) + # boxes batch size of 1 by default + input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes + encoding_image_processor.update({"input_boxes": input_boxes}) + if input_points is not None: + if return_tensors == "pt": + input_points = torch.from_numpy(input_points) + # point batch size of 1 by default + input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points + elif return_tensors == "tf": + input_points = tf.convert_to_tensor(input_points) + # point 
+ input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
+ encoding_image_processor.update({"input_points": input_points})
+ if input_labels is not None:
+ if return_tensors == "pt":
+ input_labels = torch.from_numpy(input_labels)
+ # point batch size of 1 by default
+ input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
+ elif return_tensors == "tf":
+ input_labels = tf.convert_to_tensor(input_labels)
+ # point batch size of 1 by default
+ input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
+ encoding_image_processor.update({"input_labels": input_labels})
+
+ return encoding_image_processor
+
+ def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
+ r"""
+ The method pads the 2D points and labels to the maximum number of points in the batch.
+ """
+ expected_nb_points = max([point.shape[0] for point in input_points])
+ processed_input_points = []
+ for i, point in enumerate(input_points):
+ if point.shape[0] != expected_nb_points:
+ point = np.concatenate(
+ [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
+ )
+ input_labels[i] = np.append(input_labels[i], [point_pad_value])
+ processed_input_points.append(point)
+ input_points = processed_input_points
+ return input_points, input_labels
+
+ def _normalize_coordinates(
+ self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
+ ) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
+ coords = deepcopy(coords).astype(float)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 2, 2)
+
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+
+ if is_bounding_box:
+ coords = coords.reshape(-1, 4)
+
+ return coords
+
+ def _check_and_preprocess_points(
+ self,
+ input_points=None,
+ input_labels=None,
+ input_boxes=None,
+ ):
+ r"""
+ Checks and preprocesses the 2D points, labels and bounding boxes. It checks that the inputs are valid and,
+ if they are, converts the points and bounding boxes to lists of `numpy.ndarray`s. If a user passes a
+ `torch.Tensor` directly, it is converted to a `numpy.ndarray` and then to a `list`.
+ """ + if input_points is not None: + if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor + input_points = input_points.numpy().tolist() + + if not isinstance(input_points, list) or not isinstance(input_points[0], list): + raise ValueError("Input points must be a list of list of floating points.") + input_points = [np.array(input_point) for input_point in input_points] + else: + input_points = None + + if input_labels is not None: + if hasattr(input_labels, "numpy"): + input_labels = input_labels.numpy().tolist() + + if not isinstance(input_labels, list) or not isinstance(input_labels[0], list): + raise ValueError("Input labels must be a list of list integers.") + input_labels = [np.array(label) for label in input_labels] + else: + input_labels = None + + if input_boxes is not None: + if hasattr(input_boxes, "numpy"): + input_boxes = input_boxes.numpy().tolist() + + if ( + not isinstance(input_boxes, list) + or not isinstance(input_boxes[0], list) + or not isinstance(input_boxes[0][0], list) + ): + raise ValueError("Input boxes must be a list of list of list of floating points.") + input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] + else: + input_boxes = None + + return input_points, input_labels, input_boxes + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(image_processor_input_names)) + + def post_process_masks(self, *args, **kwargs): + return self.image_processor.post_process_masks(*args, **kwargs) + + +__all__ = ["SamProcessor"] diff --git a/docs/transformers/src/transformers/models/seamless_m4t/__init__.py b/docs/transformers/src/transformers/models/seamless_m4t/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d289de0474acfadeeee08019dd57f40f9a1a4b4 --- /dev/null +++ b/docs/transformers/src/transformers/models/seamless_m4t/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_seamless_m4t import *
+ from .feature_extraction_seamless_m4t import *
+ from .modeling_seamless_m4t import *
+ from .processing_seamless_m4t import *
+ from .tokenization_seamless_m4t import *
+ from .tokenization_seamless_m4t_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/seamless_m4t/configuration_seamless_m4t.py b/docs/transformers/src/transformers/models/seamless_m4t/configuration_seamless_m4t.py
new file mode 100644
index 0000000000000000000000000000000000000000..f406264b030873e7a0075687e874dc7ca8d0cd67
--- /dev/null
+++ b/docs/transformers/src/transformers/models/seamless_m4t/configuration_seamless_m4t.py
@@ -0,0 +1,416 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SeamlessM4T model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SeamlessM4TConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`~SeamlessM4TModel`]. It is used to instantiate a
+ SeamlessM4T model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the SeamlessM4T
+ ["facebook/hf-seamless-m4t-medium"](https://huggingface.co/facebook/hf-seamless-m4t-medium) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256102):
+ Vocabulary size of the SeamlessM4T model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`~SeamlessM4TModel`], [`~SeamlessM4TForTextToSpeech`] or
+ [`~SeamlessM4TForTextToText`].
+ t2u_vocab_size (`int`, *optional*, defaults to 10082):
+ Unit vocabulary size of the SeamlessM4T model. Defines the number of different unit tokens that can be
+ represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4TModel`],
+ [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`].
+
+ > Parameters shared across sub-models
+
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the "intermediate" layers in the architecture.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model's text encoder and decoder might ever be used with. Typically
+ set this to something large just in case (e.g., 512 or 1024 or 2048).
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether the model is used as an encoder/decoder or not.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.05):
+ The LayerDrop probability for the encoders. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.05):
+ The LayerDrop probability for the decoders. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
+ `"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all attention layers.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all activation layers in the model.
+ scale_embedding (`bool`, *optional*, defaults to `True`):
+ Scale embeddings by dividing by sqrt(d_model).
+
+ > Text encoder and text decoder specific parameters
+
+ encoder_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer text encoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 8192):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer text encoder.
+ decoder_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer text decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 8192):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer text decoder.
+ decoder_start_token_id (`int`, *optional*, defaults to 3):
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
+ applied in the text decoder.
+ max_new_tokens (`int`, *optional*, defaults to 256):
+ The maximum number of text tokens to generate, ignoring the number of tokens in the prompt.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the _padding_ text token. Only applied to the text-decoder model.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model.
+ eos_token_id (`int`, *optional*, defaults to 3):
+ The id of the _end-of-stream_ text token. Only applied to the text-decoder model.
+
+ > Speech encoder specific parameters
+
+ speech_encoder_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer speech encoder.
+ speech_encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer speech encoder.
+ speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder.
+ speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
+ The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+ speech_encoder_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all layers in the speech encoder.
+ add_adapter (`bool`, *optional*, defaults to `True`):
+ Add an adapter layer on top of the speech encoder.
+ speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1):
+ The LayerDrop probability for the speech encoder. See the [LayerDrop
+ paper](https://arxiv.org/abs/1909.11556) for more details.
+ feature_projection_input_dim (`int`, *optional*, defaults to 160):
+ Input dimension of the input feature projection of the speech encoder, i.e. the dimension after processing
+ input audio with [`SeamlessM4TFeatureExtractor`].
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
+ embeddings layer of the speech encoder.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of the 1D convolutional positional embeddings layer of the speech encoder.
+ adaptor_kernel_size (`int`, *optional*, defaults to 8):
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ adaptor_stride (`int`, *optional*, defaults to 8):
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ adaptor_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all layers in the speech adapter.
+ num_adapter_layers (`int`, *optional*, defaults to 1):
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+ True`.
+ position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
+ Can be set to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
+ `None`, no relative position embedding is applied. Only applied to the speech encoder.
+ rotary_embedding_base (`int`, *optional*, defaults to 10000):
+ If `"rotary"` position embeddings are used, defines the size of the embedding base. Only applied to the
+ speech encoder.
+ max_source_positions (`int`, *optional*, defaults to 4096):
+ If `"relative"` position embeddings are used, defines the maximum source input positions. Only applied to
+ the speech encoder.
+ conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
+ Kernel size of the depthwise 1D convolutional layer in Conformer blocks. Only applied to the speech
+ encoder.
+
+ > Text-To-Unit (t2u) model specific parameters
+
+ t2u_bos_token_id (`int`, *optional*, defaults to 0):
+ The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
+ t2u_pad_token_id (`int`, *optional*, defaults to 1):
+ The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model.
+ t2u_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model.
+ t2u_decoder_start_token_id (`int`, *optional*, defaults to 2):
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
+ applied to the text-to-unit seq2seq model.
+ t2u_max_new_tokens (`int`, *optional*, defaults to 1024):
+ The maximum number of unit tokens to generate, ignoring the number of tokens in the prompt. Only applied
+ to the text-to-unit seq2seq model.
+ t2u_encoder_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer text-to-unit encoder.
+ t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder.
+ t2u_encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer text-to-unit encoder.
+ t2u_decoder_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer text-to-unit decoder.
+ t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder.
+ t2u_decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer text-to-unit decoder.
+ t2u_max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model's text-to-unit component might ever be used with. Typically
+ set this to something large just in case (e.g., 512 or 1024 or 2048).
+
+ > Hifi-Gan Vocoder specific parameters
+
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate at which the output audio will be generated, expressed in hertz (Hz).
+ upsample_initial_channel (`int`, *optional*, defaults to 512):
+ The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only.
+ upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network.
+ The length of *upsample_rates* defines the number of convolutional layers and has to match the length of
+ *upsample_kernel_sizes*. Applies to the vocoder only.
+ upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling
+ network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match
+ the length of *upsample_rates*. Applies to the vocoder only.
+ resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
+ A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive
+ field fusion (MRF) module. Applies to the vocoder only.
+ resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
+ A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in
+ the multi-receptive field fusion (MRF) module. Applies to the vocoder only.
+ leaky_relu_slope (`float`, *optional*, defaults to 0.1):
+ The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder
+ only.
+ unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000):
+ Vocabulary size of the SeamlessM4T vocoder.
Defines the number of different unit tokens that can be + represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4TModel`], + [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`]. + unit_embed_dim (`int`, *optional*, defaults to 1280): + The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only. + lang_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only. + spkr_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only. + vocoder_num_langs (`int`, *optional*, defaults to 36): + Number of langs supported by the vocoder. Might be different from `t2u_num_langs`. + vocoder_num_spkrs (`int`, *optional*, defaults to 200): + Number of speakers supported by the vocoder. + variance_predictor_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the duration predictor. Applies to the vocoder only. + var_pred_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability of the duration predictor. Applies to the vocoder only. + vocoder_offset (`int`, *optional*, defaults to 4): + Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only. + + ```python + >>> from transformers import SeamlessM4TModel, SeamlessM4TConfig + + >>> # Initializing a SeamlessM4T "facebook/hf-seamless-m4t-medium" style configuration + >>> configuration = SeamlessM4TConfig() + + >>> # Initializing a model from the "facebook/hf-seamless-m4t-medium" style configuration + >>> model = SeamlessM4TModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "seamless_m4t" + + def __init__( + self, + vocab_size=256102, + t2u_vocab_size=10082, + # shared config + hidden_size=1024, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + max_position_embeddings=1024, + is_encoder_decoder=True, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + activation_function="relu", + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + scale_embedding=True, + # text encoder|decoder + encoder_layers=24, + encoder_ffn_dim=8192, + encoder_attention_heads=16, + decoder_layers=24, + decoder_ffn_dim=8192, + decoder_attention_heads=16, + decoder_start_token_id=3, + max_new_tokens=256, + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, + # speech_encoder + speech_encoder_layers=24, + speech_encoder_attention_heads=16, + speech_encoder_intermediate_size=4096, + speech_encoder_hidden_act="swish", + speech_encoder_dropout=0.0, + add_adapter=True, + speech_encoder_layerdrop=0.1, + feature_projection_input_dim=160, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + adaptor_kernel_size=8, + adaptor_stride=8, + adaptor_dropout=0.1, + num_adapter_layers=1, + position_embeddings_type="relative", + rotary_embedding_base=10000, + max_source_positions=4096, + conv_depthwise_kernel_size=31, + # t2u config + t2u_bos_token_id=0, + t2u_pad_token_id=1, + t2u_eos_token_id=2, + t2u_decoder_start_token_id=2, + t2u_max_new_tokens=1024, + t2u_encoder_layers=6, + t2u_encoder_ffn_dim=8192, + t2u_encoder_attention_heads=16, + t2u_decoder_layers=6, + t2u_decoder_ffn_dim=8192, + t2u_decoder_attention_heads=16, + t2u_max_position_embeddings=2048, + # hifi-gan vocoder config + sampling_rate=16000, + upsample_initial_channel=512, + 
upsample_rates=[5, 4, 4, 2, 2], + upsample_kernel_sizes=[11, 8, 8, 4, 4], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_slope=0.1, + # specific to Code Hifi-Gan + unit_hifi_gan_vocab_size=10000, + unit_embed_dim=1280, + lang_embed_dim=256, + spkr_embed_dim=256, + vocoder_num_langs=36, + vocoder_num_spkrs=200, + variance_predictor_kernel_size=3, + var_pred_dropout=0.5, + vocoder_offset=4, + **kwargs, + ): + # overall_config + self.vocab_size = vocab_size + self.t2u_vocab_size = t2u_vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.max_new_tokens = max_new_tokens + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.activation_function = activation_function + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.scale_embedding = scale_embedding + # for proper config init + self.num_attention_heads = decoder_attention_heads + self.num_hidden_layers = decoder_layers + + # text|unit encoder|decoder + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_attention_heads = decoder_attention_heads + + # speech_encoder + self.speech_encoder_layers = speech_encoder_layers + self.speech_encoder_hidden_act = speech_encoder_hidden_act + self.speech_encoder_dropout = speech_encoder_dropout + self.speech_encoder_attention_heads = speech_encoder_attention_heads + self.speech_encoder_layerdrop = speech_encoder_layerdrop + self.speech_encoder_intermediate_size = speech_encoder_intermediate_size + self.feature_projection_input_dim = feature_projection_input_dim + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.adaptor_kernel_size = adaptor_kernel_size + self.adaptor_stride = adaptor_stride + self.adaptor_dropout = adaptor_dropout + self.num_adapter_layers = num_adapter_layers + self.position_embeddings_type = position_embeddings_type + self.rotary_embedding_base = rotary_embedding_base + self.max_source_positions = max_source_positions + self.conv_depthwise_kernel_size = conv_depthwise_kernel_size + self.add_adapter = add_adapter + + # t2u config + self.t2u_bos_token_id = t2u_bos_token_id + self.t2u_pad_token_id = t2u_pad_token_id + self.t2u_eos_token_id = t2u_eos_token_id + self.t2u_decoder_start_token_id = t2u_decoder_start_token_id + self.t2u_max_new_tokens = t2u_max_new_tokens + self.t2u_encoder_layers = t2u_encoder_layers + self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim + self.t2u_encoder_attention_heads = t2u_encoder_attention_heads + self.t2u_decoder_layers = t2u_decoder_layers + self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim + self.t2u_decoder_attention_heads = t2u_decoder_attention_heads + self.t2u_max_position_embeddings = t2u_max_position_embeddings + + # hifi-gan vocoder config + # original parameters specific to Hifi-Gan + self.sampling_rate = sampling_rate + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = 
resblock_dilation_sizes + self.leaky_relu_slope = leaky_relu_slope + + # specific to Code Hifi-Gan + self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size + self.unit_embed_dim = unit_embed_dim + self.lang_embed_dim = lang_embed_dim + self.spkr_embed_dim = spkr_embed_dim + self.vocoder_num_langs = vocoder_num_langs + self.vocoder_num_spkrs = vocoder_num_spkrs + self.variance_predictor_kernel_size = variance_predictor_kernel_size + self.var_pred_dropout = var_pred_dropout + self.vocoder_offset = vocoder_offset + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + is_encoder_decoder=is_encoder_decoder, + max_position_embeddings=max_position_embeddings, + **kwargs, + ) + + +__all__ = ["SeamlessM4TConfig"] diff --git a/docs/transformers/src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py b/docs/transformers/src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..7bef416ec3757f4eb172f0c57135cb2ba58d3540 --- /dev/null +++ b/docs/transformers/src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
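Most `__init__` arguments above map onto attributes one-to-one; the only derived fields are `num_attention_heads` and `num_hidden_layers`, which mirror the text decoder. A hedged sketch of a downsized configuration, using the same overrides the conversion script below applies for the medium checkpoint:

```python
from transformers import SeamlessM4TConfig

# Hedged sketch: a smaller text/speech stack, mirroring the "medium"
# overrides in the conversion script that follows.
config = SeamlessM4TConfig(
    encoder_layers=12,
    decoder_layers=12,
    encoder_ffn_dim=4096,
    decoder_ffn_dim=4096,
    speech_encoder_layers=12,
)
assert config.num_hidden_layers == config.decoder_layers == 12
```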
+"""Converting Meta SeamlessM4T checkpoints from seamless_communication to HF.""" + +import argparse +import os +from pathlib import Path + +import torch +from accelerate.utils.modeling import find_tied_parameters +from seamless_communication.models.inference.translator import Translator + +from transformers import ( + SeamlessM4TConfig, + SeamlessM4TFeatureExtractor, + SeamlessM4TModel, + SeamlessM4TProcessor, + SeamlessM4TTokenizer, +) +from transformers.utils import logging + + +UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ] # fmt: skip +VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",] # fmt: skip +MEDIUM_SUPPORTED_LANGUAGES = ["ace","ace_Latn","acm","acq","aeb","afr","ajp","aka","amh","apc","arb","ars","ary","arz","asm","ast","awa","ayr","azb","azj","bak","bam","ban","bel","bem","ben","bho","bjn","bjn_Latn","bod","bos","bug","bul","cat","ceb","ces","cjk","ckb","crh","cym","dan","deu","dik","dyu","dzo","ell","eng","epo","est","eus","ewe","fao","pes","fij","fin","fon","fra","fur","fuv","gla","gle","glg","grn","guj","hat","hau","heb","hin","hne","hrv","hun","hye","ibo","ilo","ind","isl","ita","jav","jpn","kab","kac","kam","kan","kas","kas_Deva","kat","knc","knc_Latn","kaz","kbp","kea","khm","kik","kin","kir","kmb","kon","kor","kmr","lao","lvs","lij","lim","lin","lit","lmo","ltg","ltz","lua","lug","luo","lus","mag","mai","mal","mar","min","mkd","plt","mlt","mni","khk","mos","mri","zsm","mya","nld","nno","nob","npi","nso","nus","nya","oci","gaz","ory","pag","pan","pap","pol","por","prs","pbt","quy","ron","run","rus","sag","san","sat","scn","shn","sin","slk","slv","smo","sna","snd","som","sot","spa","als","srd","srp","ssw","sun","swe","swh","szl","tam","tat","tel","tgk","tgl","tha","tir","taq","taq_Tfng","tpi","tsn","tso","tuk","tum","tur","twi","tzm","uig","ukr","umb","urd","uzn","vec","vie","war","wol","xho","ydd","yor","yue","cmn","cmn_Hant","zul",] # fmt: skip +LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",] # fmt: skip + + +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + 
assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +vocoder_convert_list = [ + ("ups", "hifi_gan.upsampler"), + ("conv_pre", "hifi_gan.conv_pre"), + ("resblocks", "hifi_gan.resblocks"), + ("conv_post", "hifi_gan.conv_post"), + ("lang", "language_embedding"), + ("spkr", "speaker_embedding"), + ("dict.", "unit_embedding."), + ("dur_predictor.conv1.0", "dur_predictor.conv1"), + ("dur_predictor.conv2.0", "dur_predictor.conv2"), +] + +# order is important +wav2vec_convert_list = [ + ("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"), + ("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"), + ("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"), + ("speech_encoder.inner.layers", "encoder.layers"), + ("speech_encoder.inner_layer_norm", "encoder.layer_norm"), + ("speech_encoder.adaptor_layers", "adapter.layers"), + ("inner_proj", "intermediate_dense"), + ("self_attn.output_proj", "self_attn.linear_out"), + ("output_proj", "output_dense"), + ("self_attn.k_proj", "self_attn.linear_k"), + ("self_attn.v_proj", "self_attn.linear_v"), + ("self_attn.q_proj", "self_attn.linear_q"), + ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"), + ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"), + ("self_attn.sdpa.r_proj", "self_attn.linear_pos"), + ("conv.pointwise_conv1", "conv_module.pointwise_conv1"), + ("conv.pointwise_conv2", "conv_module.pointwise_conv2"), + ("conv.depthwise_conv", "conv_module.depthwise_conv"), + ("conv.batch_norm", "conv_module.batch_norm"), + ("conv_layer_norm", "conv_module.layer_norm"), + ("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"), + ("speech_encoder.proj2", "intermediate_ffn.output_dense"), + ("speech_encoder.layer_norm", "inner_layer_norm"), +] + +t2u_convert_list = [ + ("t2u_model.final_proj", "lm_head"), + ("t2u_model.", "model."), + ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"), + ("encoder_decoder_attn", "cross_attention"), + ("linear_k", "k_proj"), + ("linear_v", "v_proj"), + ("linear_q", "q_proj"), + ("ffn.inner_proj", "ffn.fc1"), + ("ffn.output_proj", "ffn.fc2"), + ("output_proj", "out_proj"), + ("decoder_frontend.embed", "decoder.embed_tokens"), +] + +text_convert_list = [ + ("text_encoder.", ""), + ("text_decoder.", ""), + ("text_encoder_frontend.embed", "embed_tokens"), + ("text_decoder_frontend.embed", "embed_tokens"), + ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"), + ("encoder_decoder_attn", "cross_attention"), + ("linear_k", "k_proj"), + ("linear_v", "v_proj"), + ("linear_q", "q_proj"), + ("ffn.inner_proj", "ffn.fc1"), + ("ffn.output_proj", "ffn.fc2"), + ("output_proj", "out_proj"), + ("final_proj", "lm_head"), +] + +CUR_PATH = os.path.dirname(os.path.abspath(__file__)) +default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") +CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub") + + +def _load_hf_config(model_type="medium"): + if model_type == "medium": + kwargs = { + "vocab_size": 256206, + "t2u_vocab_size": 10082, + "hidden_size": 1024, + "max_position_embeddings": 4096, + 
"encoder_layers": 12, + "decoder_layers": 12, + "encoder_ffn_dim": 4096, + "decoder_ffn_dim": 4096, + "t2u_encoder_layers": 4, + "t2u_decoder_layers": 4, + "speech_encoder_layers": 12, + } + return SeamlessM4TConfig(**kwargs) + else: + return SeamlessM4TConfig() + + +def _convert_model( + original_model, + hf_model, + convert_list, + device, + unwanted_prefix="model.", + filter_state_dict="speech", + exclude_state_dict=None, +): + state_dict = original_model.state_dict() + + # filter func + if isinstance(filter_state_dict, str): + + def filter_func(x): + return filter_state_dict in x[0] + + else: + + def filter_func(item): + if exclude_state_dict is not None and exclude_state_dict in item[0]: + return False + for filter_el in filter_state_dict: + if filter_el in item[0]: + return True + + return False + + state_dict = dict(filter(filter_func, state_dict.items())) + + for k, v in list(state_dict.items()): + new_k = k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + + # must do it by hand + if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric(): + new_k = new_k.replace("layer_norm", "final_layer_norm") + + state_dict[new_k] = state_dict.pop(k) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + extra_keys = set(extra_keys) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k}) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=False) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model + + +def load_model(save_dir, model_type, repo_id): + """ + Meta SeamlessM4T is made of 8 main components: + - speech_encoder (#1) and speech_encoder_frontend (#2) + - t2u_model (#3) + - text_encoder (#4) and text_encoder_frontend (#5) + - text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend] + - final_proj (#7) + - vocoder (#8) + """ + device = _grab_best_device() + if model_type == "medium": + name = "seamlessM4T_medium" + else: + name = "seamlessM4T_large" + + original_model = Translator(name, "vocoder_36langs", device, torch.float32) + + ######### TOKENIZER + + langs = MEDIUM_SUPPORTED_LANGUAGES if model_type == "medium" else LARGE_SUPPORTED_LANGUAGES + langs = [f"__{lang}__" for lang in langs] + vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model") + + save_dir = os.path.join(save_dir, name) + Path(save_dir).mkdir(exist_ok=True) + + tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs) + + sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__") + + tokenizer.save_pretrained(save_dir) + tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir) + + if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"): + raise ValueError( + f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}" + ) + + ####### get language to ids dict + text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs} + # offset: vocoder unit 
vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages) + t2u_lang_code_to_id = { + code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES) + for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES) + } + vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)} + + ######### FE + + fe = SeamlessM4TFeatureExtractor(language_code=langs) + + fe.save_pretrained(save_dir) + fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir) + + processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer) + processor.save_pretrained(save_dir) + processor.push_to_hub(repo_id=repo_id, create_pr=True) + + processor = SeamlessM4TProcessor.from_pretrained(save_dir) + + ######## Model + + # init model + hf_config = _load_hf_config(model_type) + hf_model = SeamlessM4TModel(hf_config) + + hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id) + hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id) + hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id) + + # -1. take care of vocoder + # similarly to speech T5 must apply and remove weight norm + hf_model.vocoder.apply_weight_norm() + hf_model.vocoder = _convert_model( + original_model, + hf_model.vocoder, + vocoder_convert_list, + device, + unwanted_prefix="vocoder.code_generator.", + filter_state_dict="vocoder", + ) + hf_model.vocoder.remove_weight_norm() + + # 1. take care of speech encoder + wav2vec = hf_model.speech_encoder + hf_model.speech_encoder = _convert_model( + original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech" + ) + + # 2. take care of t2u + + hf_model.t2u_model = _convert_model( + original_model, + hf_model.t2u_model, + t2u_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict="t2u_model", + ) + + # 3. take care of text encoder + hf_model.text_encoder = _convert_model( + original_model, + hf_model.text_encoder, + text_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict=["model.text_encoder"], + exclude_state_dict="t2u_model", + ) + + # 4. take care of text decoder + hf_model.text_decoder = _convert_model( + original_model, + hf_model.text_decoder, + text_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict=["model.text_decoder"], + exclude_state_dict="t2u_model", + ) + + # 5. 
take care of final proj + hf_model.lm_head = _convert_model( + original_model, + hf_model.lm_head, + [("final_proj.", "")], + device, + unwanted_prefix="model.", + filter_state_dict=["model.final_proj"], + exclude_state_dict="t2u_model", + ) + + # sanity check + print(find_tied_parameters(hf_model)) + + count_1 = param_count(hf_model) + count_2 = param_count(original_model) + + print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}") + print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}") + + del original_model + + hf_model.generation_config._from_model_config = False + hf_model.save_pretrained(save_dir) + hf_model.push_to_hub(repo_id=repo_id, create_pr=True) + hf_model = SeamlessM4TModel.from_pretrained(save_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + + parser.add_argument( + "--model_type", + default="medium", + type=str, + help="Model type.", + ) + + parser.add_argument( + "--save_dir", + default="/home/ubuntu/weights", + type=str, + help="Path to the output PyTorch model.", + ) + + parser.add_argument( + "--repo_id", + default="facebook/hf-seamless-m4t-medium", + type=str, + help="Repo ID.", + ) + + args = parser.parse_args() + + load_model(args.save_dir, args.model_type, args.repo_id) diff --git a/docs/transformers/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/docs/transformers/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..b17dcf792e187bf0d52b2beb5764cad0043d2723 --- /dev/null +++ b/docs/transformers/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -0,0 +1,309 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for SeamlessM4T +""" + +from typing import List, Optional, Union + +import numpy as np + +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a SeamlessM4T feature extractor. + + This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). 
+ num_mel_bins (`int`, *optional*, defaults to 80): + Number of Mel-frequency bins. + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding vectors. + stride (`int`, *optional*, defaults to 2): + Stride used to reshape audios from shape (batch_size,num_frames,num_mel_bins) to + (batch_size,num_frames//stride,num_mel_bins*stride). + """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + num_mel_bins=80, + padding_value=0.0, + stride=2, + **kwargs, + ): + self.num_mel_bins = num_mel_bins + self.return_attention_mask = True + self.stride = stride + + mel_filters = mel_filter_bank( + num_frequency_bins=257, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = mel_filters + self.window = window_function(400, "povey", periodic=False) + + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + + @staticmethod + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0 + ) -> List[np.ndarray]: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _extract_fbank_features( + self, + waveform: np.ndarray, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. + """ + # by default, it extracts the left channel if stereo + if len(waveform.shape) == 2: + waveform = waveform[0] + + waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers + features = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + return features + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + max_length: Optional[int] = None, + truncation: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + do_normalize_per_mel_bins: Optional[bool] = True, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). 
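+
+        Example (an illustrative sketch, not part of the original file; it assumes a 1-second mono 16 kHz
+        waveform and reuses the checkpoint name that appears elsewhere in this diff):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import SeamlessM4TFeatureExtractor
+
+        >>> feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/hf-seamless-m4t-medium")
+        >>> waveform = np.zeros(16000, dtype=np.float32)  # 1 second of silence
+        >>> inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
+        >>> # inputs.input_features has shape (batch, num_frames // stride, num_mel_bins * stride)
+        ```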
+
+        Args:
+            raw_speech (`np.ndarray`, `torch.Tensor`, `List[float]`, `List[np.ndarray]`, `List[torch.Tensor]`,
+            `List[List[float]]`, `List[List[List[float]]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
+                a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
+                a list of list of float values or a list of a list of list of float values.
+                If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `List[float]`, `raw_speech` is
+                considered a single-channel, single-sample sound. In all other cases, the first dimension of
+                `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `List[...]`,
+                corresponds to the number of samples in the batch, and the number of channels
+                (i.e. mono or stereo character) is derived from the other dimensions
+                (1D -> single-channel waveform batches; 2D -> stereo-channel waveform batches).
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            pad_to_multiple_of (`int`, *optional*, defaults to 2):
+                If set, will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            truncation (`bool`, *optional*, defaults to `False`):
+                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific feature_extractor's default.
+
+                [What are attention masks?](../glossary#attention-mask)
+
+                <Tip>
+
+                For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle
+                bugs.
+
+                </Tip>
+
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            sampling_rate (`int`, *optional*):
+                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                `sampling_rate` at the forward call to prevent silent errors.
+            do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`):
+                Whether or not to zero-mean unit-variance normalize the input per mel-channel.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature
+                extractor.
+        """
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. 
Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + return_attention_mask = ( + return_attention_mask if return_attention_mask is not None else self.return_attention_mask + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 3: + raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}") + + acceptable_types = ( + (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list) + ) + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types)) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features + features = [self._extract_fbank_features(waveform) for waveform in raw_speech] + + if do_normalize_per_mel_bins: + # torch defaults to ddof=1, and numpy defaults to ddof=0 + features = [ + (x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7) + for x in features + ] + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_features": features}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=True, + return_tensors="np", + ) + + # SeamlessM4T needs to process extracted features + input_features = padded_inputs.get("input_features") + attention_mask = padded_inputs.pop("attention_mask") + + batch_size, num_frames, num_channels = input_features.shape + + remainder = num_frames % self.stride + if remainder != 0: + input_features = input_features[:, : num_frames - remainder, :] + attention_mask = attention_mask[:, : num_frames - remainder] + + input_features = np.reshape( + input_features, (batch_size, num_frames // self.stride, num_channels * self.stride) + ) + + indices = np.arange(0, num_frames - remainder) + attention_mask = attention_mask[:, indices % self.stride == 1] + + padded_inputs["input_features"] = input_features + if return_attention_mask: + padded_inputs["attention_mask"] = attention_mask + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["SeamlessM4TFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/docs/transformers/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7c75844f756a785c67f0989be81efec7906e93 --- /dev/null +++ b/docs/transformers/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -0,0 +1,4308 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SeamlessM4T model."""
+
+import copy
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Wav2Vec2BaseModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+from .configuration_seamless_m4t import SeamlessM4TConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/hf-seamless-m4t-medium"
+_CONFIG_FOR_DOC = "SeamlessM4TConfig"
+
+
+@dataclass
+class SeamlessM4TGenerationOutput(ModelOutput):
+    """
+    Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`],
+    [`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForSpeechToText`].
+
+    Args:
+        waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            The final audio waveform predicted by the model.
+        waveform_lengths (`torch.IntTensor` of shape `(batch_size,)`, *optional*):
+            The length in samples of each element in the `waveform` batch.
+        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            The generated translated sequences. This is the output of the text-to-text or the speech-to-text models.
+            The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished
+            early due to the `eos_token_id`.
+        unit_sequences (`torch.LongTensor` of shape `(batch_size, unit_sequence_length)`, *optional*):
+            The generated translated unit sequences. This is the output of the text-to-units model. The second
+            dimension (unit_sequence_length) is either equal to `t2u_max_length` or shorter if all batches finished
+            early due to the `t2u_eos_token_id`.
+    """
+
+    waveform: Optional[torch.FloatTensor] = None
+    waveform_lengths: Optional[torch.IntTensor] = None
+    sequences: Optional[Tuple[torch.FloatTensor]] = None
+    unit_sequences: Optional[Tuple[torch.FloatTensor]] = None
+
+
+SEAMLESS_M4T_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`~SeamlessM4TConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SEAMLESS_M4T_INPUTS_DOCSTRING_FIRST_PART = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+            Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+            [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+    """
+
+SEAMLESS_M4T_INPUTS_DOCSTRING_TEXT_PART = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+    """
+
+SEAMLESS_M4T_INPUTS_DOCSTRING_SPEECH_PART = r"""
+    Args:
+        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+            Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+            [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+    """
+
+SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART = r"""
+    attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+        - 1 for tokens that are **not masked**,
+        - 0 for tokens that are **masked**.
+
+        [What are attention masks?](../glossary#attention-mask)
+    decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+        Indices of decoder input sequence tokens in the vocabulary.
+
+        Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+        [`PreTrainedTokenizer.__call__`] for details.
+
+        [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+        Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
+        is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
+
+        For translation and summarization training, `decoder_input_ids` should be provided. If no
+        `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+        for denoising pre-training following the paper.
+    decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+        Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+        be used by default.
+
+        If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
+        and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+        information on the default strategy.
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape`(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
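+
+    Example of a full text-to-speech call through [`SeamlessM4TModel`] (an illustrative sketch, not part of the
+    original file; the checkpoint is the one used elsewhere in this diff and the `eng`/`fra` language codes are
+    assumed):
+
+    ```python
+    >>> from transformers import AutoProcessor, SeamlessM4TModel
+
+    >>> processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
+    >>> model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+    >>> inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
+    >>> waveform, waveform_lengths = model.generate(**inputs, tgt_lang="fra")
+    ```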
+""" + +M4T_MODEL_INPUTS_DOCSTRING = SEAMLESS_M4T_INPUTS_DOCSTRING_FIRST_PART + SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART + +M4T_TEXT_INPUTS_DOCSTRING = SEAMLESS_M4T_INPUTS_DOCSTRING_TEXT_PART + SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART + +M4T_SPEECH_INPUTS_DOCSTRING = SEAMLESS_M4T_INPUTS_DOCSTRING_SPEECH_PART + SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART + + +############ UTILS ################ + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): + """ + Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that + stops at the corresponding element in `seq_lens`. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): + The sequences to mask, where `*` is any number of sequence-specific dimensions including none. + seq_lens (`torch.Tensor` of shape `(batch)`: + Each element represents the length of the sequence at the same index in `hidden_states` + + Returns: + `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` + """ + batch_size, mask_seq_len = hidden_states.shape[:2] + + indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) + + bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) + + mask = hidden_states.new_ones((batch_size, mask_seq_len)) + + mask = mask.masked_fill(bool_mask, 0) + + return mask + + +def format_speech_generation_kwargs(kwargs): + """ + Format kwargs for SeamlessM4T models that generate speech, attribute kwargs to either the text generation or the + speech generation models. + + Args: + kwargs (`dict`)`: + Keyword arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. 
+ + This means you can, for example, specify a generation strategy for one generation but not for the + other. + """ + # attribute kwargs to models + kwargs_text = {} + kwargs_speech = {} + for key, value in kwargs.items(): + if key.startswith("text_"): + key = key[len("text_") :] + kwargs_text[key] = value + elif key.startswith("speech_"): + key = key[len("speech_") :] + kwargs_speech[key] = value + elif key == "generation_config": + kwargs_text[key] = value + else: + # If the key is already in a specific config, then it's been set with a + # submodules specific value and we don't override + if key not in kwargs_text: + kwargs_text[key] = value + if key not in kwargs_speech: + kwargs_speech[key] = value + return kwargs_text, kwargs_speech + + +############ SPEECH ENCODER related code ################ + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SeamlessM4TConformer, feat_extract_activation->speech_encoder_hidden_act +class SeamlessM4TConformerPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SeamlessM4TConformerSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2->SeamlessM4T, num_attention_heads->speech_encoder_attention_heads +class SeamlessM4TConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.speech_encoder_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return 
self.cached_rotary_positional_embedding
+
+        self.cached_sequence_length = sequence_length
+        # Embeddings are computed in the dtype of the inv_freq constant
+        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+        embeddings = torch.cat((freqs, freqs), dim=-1)
+
+        cos_embeddings = embeddings.cos()[:, None, None, :]
+        sin_embeddings = embeddings.sin()[:, None, None, :]
+        # Computed embeddings are cast to the dtype of the hidden state inputs
+        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
+        return self.cached_rotary_positional_embedding
+
+
+# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2->SeamlessM4T
+class SeamlessM4TConformerRelPositionalEmbedding(nn.Module):
+    """Relative positional encoding module."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.max_len = config.max_source_positions
+        self.d_model = config.hidden_size
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+    def extend_pe(self, x):
+        # Reset the positional encodings
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of query vector and `j` is the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[:, 1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, hidden_states: torch.Tensor):
+        self.extend_pe(hidden_states)
+        start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
+        end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
+        relative_position_embeddings = self.pe[:, start_idx:end_idx]
+
+        return relative_position_embeddings
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SeamlessM4T
+class SeamlessM4TConformerSamePadLayer(nn.Module):
+    def __init__(self, num_conv_pos_embeddings):
+        super().__init__()
+        self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+    def forward(self, hidden_states):
+        if self.num_pad_remove > 0:
+            hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+        return hidden_states
+
+
+class SeamlessM4TConformerFeatureProjection(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps)
+        self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size)
+        self.dropout = nn.Dropout(config.speech_encoder_dropout)
+
+    def forward(self, hidden_states):
+        # non-projected hidden states are needed for quantization
+        norm_hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(norm_hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class SeamlessM4TConformerFeedForward(nn.Module):
+    def __init__(self, config, act_fn=None, dropout=None):
+        super().__init__()
+        dropout = dropout if dropout is not None else config.speech_encoder_dropout
+        act_fn = act_fn if act_fn is not None else config.speech_encoder_hidden_act
+
+        self.intermediate_dropout = nn.Dropout(dropout)
+        self.intermediate_dense = nn.Linear(config.hidden_size, config.speech_encoder_intermediate_size)
+        self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn
+
+        self.output_dense = nn.Linear(config.speech_encoder_intermediate_size, config.hidden_size)
+        self.output_dropout = nn.Dropout(dropout)
+
+    def forward(self, hidden_states):
+        hidden_states = self.intermediate_dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = 
self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.pointwise_conv1 = nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding="same", + groups=config.hidden_size, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(config.hidden_size) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + self.pointwise_conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def forward(self, hidden_states, attention_mask=None): + hidden_states = self.layer_norm(hidden_states) + + # Ensure that we do not leak padded positions in depthwise convolution. + # Put 0 where necessary + if attention_mask is not None: + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class SeamlessM4TConformerSelfAttention(nn.Module): + """Construct a SeamlessM4TConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. 
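+
+    A standalone sketch of the "rotate-half" trick used by the rotary branch (`_apply_rotary_embedding` below),
+    with toy shapes, for illustration only:
+
+    ```python
+    >>> import torch
+
+    >>> head_size = 4
+    >>> x = torch.randn(1, 3, 2, head_size)  # (batch, seq_len, num_heads, head_size)
+    >>> cos, sin = torch.randn(3, 1, 1, head_size), torch.randn(3, 1, 1, head_size)
+    >>> x = x.transpose(0, 1)  # (seq_len, batch, num_heads, head_size)
+    >>> rotated = torch.cat((-x[..., head_size // 2 :], x[..., : head_size // 2]), dim=-1)
+    >>> out = (x * cos + rotated * sin).transpose(0, 1)  # back to (batch, seq_len, ...)
+    ```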
+ """ + + def __init__(self, config, use_position_embeddings=True): + super().__init__() + + self.head_size = config.hidden_size // config.speech_encoder_attention_heads + self.num_heads = config.speech_encoder_attention_heads + self.position_embeddings_type = config.position_embeddings_type if use_position_embeddings else None + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.speech_encoder_dropout) + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_embeddings_type == "relative": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" + " 'relative'" + ) + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860 + scores = self._apply_relative_embeddings( + query=query, key=key, relative_position_embeddings=relative_position_embeddings + ) + else: + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = torch.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = 
self.linear_out(hidden_states) + + return hidden_states, probs + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. 
sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class SeamlessM4TConformerEncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4T, attention_dropout->speech_encoder_dropout, torch.nn->nn + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.speech_encoder_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = SeamlessM4TConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = SeamlessM4TConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = SeamlessM4TConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = SeamlessM4TConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + conv_attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = residual + hidden_states + + # 4. 
Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class SeamlessM4TConformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = SeamlessM4TConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = SeamlessM4TConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.dropout = nn.Dropout(config.speech_encoder_dropout) + self.layers = nn.ModuleList( + [SeamlessM4TConformerEncoderLayer(config) for _ in range(config.speech_encoder_layers)] + ) + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = ( + True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False + ) + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, 
+ attentions=all_self_attentions, + ) + + +class SeamlessM4TConformerAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.adaptor_dropout + + self.kernel_size = config.adaptor_kernel_size + self.stride = config.adaptor_stride + + # 1. residual convolution + self.residual_layer_norm = nn.LayerNorm(embed_dim) + self.residual_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.activation = nn.GLU(dim=1) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.self_attn = SeamlessM4TConformerSelfAttention(config, use_position_embeddings=False) + self.self_attn_dropout = nn.Dropout(dropout) + + # Feed-forward + self.ffn_layer_norm = nn.LayerNorm(embed_dim) + self.ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=dropout) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + pad = self.kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 + + return seq_lens.floor() + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + residual = self.residual_layer_norm(hidden_states) + + # Apply pooling to the residual to match the sequence length of the + # multi-head attention output. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + residual = residual.transpose(1, 2) + residual = self.residual_conv(residual) + residual = self.activation(residual) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + residual = residual.transpose(1, 2) + + hidden_states = self.self_attn_layer_norm(hidden_states) + # Apply pooling before feeding to the multihead-attention layer. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.self_attn_conv(hidden_states) + hidden_states = self.activation(hidden_states) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + hidden_states = hidden_states.transpose(1, 2) + + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + hidden_states.device + ) + attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) + attention_mask = _prepare_4d_attention_mask( + attention_mask, + hidden_states.dtype, + ) + + # The rest of the computation is identical to a vanilla Transformer + # encoder layer. 
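+        # Illustration of `_compute_sub_sample_lengths_from_attention_mask` above (added note, assuming the
+        # default adaptor_kernel_size=8 and adaptor_stride=8, hence pad=4): a 160-frame input is pooled to
+        # floor((160 + 2 * 4 - 8) / 8) + 1 = 21 positions, which is the length the new attention mask is built for.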
+ hidden_states, attn_weigths = self.self_attn( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + residual + + return hidden_states + + +class SeamlessM4TConformerAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + self.layers = nn.ModuleList(SeamlessM4TConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)) + + def forward(self, hidden_states, attention_mask): + # down project hidden_states if necessary + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + + return hidden_states + + +############ TEXT / UNITS related code ################ + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T +class SeamlessM4TScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding +class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". 
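+
+        A standalone sketch of the construction below (toy sizes, for illustration only):
+
+        ```python
+        >>> import math
+
+        >>> import torch
+
+        >>> num_embeddings, embedding_dim = 6, 4
+        >>> half_dim = embedding_dim // 2
+        >>> freqs = torch.exp(torch.arange(half_dim).float() * -(math.log(10000) / (half_dim - 1)))
+        >>> angles = torch.arange(num_embeddings).float().unsqueeze(1) * freqs.unsqueeze(0)
+        >>> emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (6, 4): [sin | cos] per position
+        ```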
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +class SeamlessM4TAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.bart.modeling_bart.BartAttention.__init__ with Bart->SeamlessM4T + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[SeamlessM4TConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `encoder_hidden_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
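+        # (bsz, tgt_len, num_heads, head_dim) -> (bsz, tgt_len, embed_dim)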
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4T,DenseActDense->FeedForwardNetwork, d_model->hidden_size +class SeamlessM4TFeedForwardNetwork(nn.Module): + def __init__(self, config: SeamlessM4TConfig, ffn_dim: int): + super().__init__() + self.fc1 = nn.Linear(config.hidden_size, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, config.hidden_size) + self.dropout = nn.Dropout(config.activation_dropout) + self.act = ACT2FN[config.activation_function] + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.fc2.weight, torch.Tensor) + and hidden_states.dtype != self.fc2.weight.dtype + and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8) + ): + hidden_states = hidden_states.to(self.fc2.weight.dtype) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SeamlessM4TEncoderLayer(nn.Module): + def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None): + super().__init__() + encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim + encoder_attention_heads = ( + config.encoder_attention_heads if encoder_attention_heads is None else encoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=encoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. 
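+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of the self-attention layer. See `attentions` under
+                returned tensors for more detail.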
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SeamlessM4TDecoderLayer(nn.Module): + def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): + super().__init__() + decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim + decoder_attention_heads = ( + config.decoder_attention_heads if decoder_attention_heads is None else decoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = SeamlessM4TAttention( + self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=decoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + past_key_value (`Tuple(torch.FloatTensor)`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=cross_attn_past_key_value, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value += cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +############ SUB-MODELS related code ################ + + +class SeamlessM4TPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SeamlessM4TConfig + base_model_prefix = "seamless_m4t" + supports_gradient_checkpointing = True + _no_split_modules = ["SeamlessM4TEncoderLayer", "SeamlessM4TDecoderLayer", "SeamlessM4TConformerEncoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, SeamlessM4TConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, SeamlessM4TConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, SeamlessM4TConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride + pad = kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1 + + return seq_lens.floor() + + def compute_last_hidden_states_per_sample( + self, + hidden_states: Tuple[Tuple[torch.Tensor]], + beam_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Computes the last hidden states. + + Parameters: + hidden_states (`Tuple[Tuple[torch.Tensor]]`): + The generated hidden states. Tuple (one element for each generated token) of tuples (one element for + each layer of the decoder) of torch.FloatTensor of shape (batch_size*num_beams*num_return_sequences, + generated_length, hidden_size). + beam_indices (`torch.LongTensor`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at + generate-time. + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length, hidden_size)` + containing + the last hidden states. + ```""" + # 1. First, let's compute last_hidden_states from hidden_states. + # For each generation step, takes the hidden state from the last layer. + # shape: (batch_size*vocab_size*num_return_sequences, # generation_steps, hidden_dim) + last_hidden_states = torch.concat([hidden_states[-1] for hidden_states in hidden_states], dim=1) + + # 2. In absence of `beam_indices`, we can assume that we come from e.g. 
+        # equivalent to a beam search approach where the first (and only) beam is always selected.
+        # In that case, directly return last_hidden_states.
+        if beam_indices is None:
+            return last_hidden_states
+
+        # 3. cut beam_indices to longest beam length
+        beam_indices_mask = beam_indices < 0
+        max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
+        beam_indices = beam_indices.clone()[:, :max_beam_length]
+        beam_indices_mask = beam_indices_mask[:, :max_beam_length]
+
+        # 4. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards anyway
+        beam_indices[beam_indices_mask] = 0
+
+        # 5. expand beam_indices to last_hidden_states dim
+        beam_indices = beam_indices.unsqueeze(-1)
+        beam_indices = beam_indices.expand(-1, -1, last_hidden_states.shape[-1])
+
+        # 6. select the right candidate for each beam
+        # in other words, new_last_hidden_states[i,j,k] = last_hidden_states[beam_indices[i,j,k], j, k] for all i, j, k
+        last_hidden_states = torch.gather(last_hidden_states, 0, beam_indices)
+
+        return last_hidden_states
+
+
+@add_start_docstrings(
+    """Transformer speech encoder consisting of *config.speech_encoder_layers* conformer self-attention layers.
+    Each layer is a [`SeamlessM4TConformerEncoderLayer`].""",
+    SEAMLESS_M4T_START_DOCSTRING,
+)
+class SeamlessM4TSpeechEncoder(SeamlessM4TPreTrainedModel):
+    main_input_name = "input_features"
+
+    def __init__(self, config: SeamlessM4TConfig):
+        super().__init__(config)
+
+        self.feature_projection = SeamlessM4TConformerFeatureProjection(config)
+        self.encoder = SeamlessM4TConformerEncoder(config)
+        self.intermediate_ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=0.0)
+        self.adapter = SeamlessM4TConformerAdapter(config) if config.add_adapter else None
+        self.inner_layer_norm = nn.LayerNorm(config.hidden_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_features: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_features is None:
+            raise ValueError(
+                """Both `input_features` and `inputs_embeds` are `None` in `SeamlessM4TSpeechEncoder.forward`.
+ Make sure one of them is not `None`.""" + ) + + hidden_states = self.feature_projection(input_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + expanded_hidden_states = self.intermediate_ffn(hidden_states) + hidden_states = hidden_states + 0.5 * expanded_hidden_states + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) + + hidden_states = self.inner_layer_norm(hidden_states) + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# inspired from MBart and NllbMoe +@add_start_docstrings( + "Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`SeamlessM4TEncoderLayer`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + is_t2u_encoder (`bool`, *optional*, defaults to `False`): + indicates if it belongs to the text-to-units model, in which case it won't have input embeddings + """, +) +class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + is_t2u_encoder: bool = False, + ): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + embed_dim = config.hidden_size + + self.is_t2u_encoder = is_t2u_encoder + self.max_source_positions = config.max_position_embeddings + + if not self.is_t2u_encoder: + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + + layers = [] + for _ in range(config.encoder_layers): + layers.append( + SeamlessM4TEncoderLayer( + config, + encoder_attention_heads=config.encoder_attention_heads, + encoder_ffn_dim=config.encoder_ffn_dim, + ) + ) + + self.layers = nn.ModuleList(layers) + + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and self.is_t2u_encoder: + raise ValueError( + "You cannot pass input_ids to the encoder of the text_to_units model. Pass inputs_embeds instead." + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.is_t2u_encoder: + embed_pos = self.embed_positions(input) + + hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device) + else: + hidden_states = inputs_embeds + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.forward, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + 
output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +@add_start_docstrings( + "Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SeamlessM4TDecoderLayer`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + """, +) +class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + # if embed_tokens defined, use its shape instead + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_tokens.weight = embed_tokens.weight + else: + self.embed_tokens = SeamlessM4TScaledWordEmbedding( + self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + layers = [] + for _ in range(config.decoder_layers): + layers.append( + SeamlessM4TDecoderLayer( + config, + decoder_attention_heads=config.decoder_attention_heads, + decoder_ffn_dim=config.decoder_ffn_dim, + ) + ) + self.layers = nn.ModuleList(layers) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
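+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).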
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[3],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "Transformer bare text-to-unit encoder-decoder. The encoder is a [`SeamlessM4TEncoder`] without embeddings and the decoder is a [`SeamlessM4TDecoder`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. 
+ """, +) +class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + super().__init__(config) + + self.encoder = SeamlessM4TEncoder(config, is_t2u_encoder=True) + self.decoder = SeamlessM4TDecoder(config, embed_tokens_decoder) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "Transformer text-to-unit encoder-decoder with a language model head. The base encoder-decoder model is a [`SeamlessM4TTextToUnit`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. 
+ """, +) +class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = [ + "vocoder", + "speech_encoder", + "text_encoder", + "text_decoder", + ] + _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + # update config - used principaly for bos_token_id etc. + config = copy.deepcopy(config) + for param, val in config.to_dict().items(): + if param.startswith("t2u_"): + config.__setattr__(param[4:], val) + super().__init__(config) + + self.model = SeamlessM4TTextToUnitModel(config, embed_tokens_decoder) + + self.lm_head = nn.Linear(config.hidden_size, config.t2u_vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, 
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id)
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
+    def _tie_weights(self) -> None:
+        if getattr(self.config, "tie_word_embeddings", True):
+            output_embeddings = self.get_output_embeddings()
+            if output_embeddings is not None:
+                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+
+
+############ VOCODER related code ################
+
+
+HIFIGAN_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SeamlessM4TConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+""" + + +# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class SeamlessM4TVariancePredictor(nn.Module): + def __init__(self, config): + super().__init__() + + embed_dim = config.unit_embed_dim + kernel_size = config.variance_predictor_kernel_size + var_pred_dropout = config.var_pred_dropout + + self.conv1 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + self.activation_fuction = nn.ReLU() + self.ln1 = nn.LayerNorm(embed_dim) + self.dropout_module = nn.Dropout(p=var_pred_dropout) + self.conv2 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=1, + ) + self.ln2 = nn.LayerNorm(embed_dim) + self.proj = nn.Linear(embed_dim, 1) + + def forward(self, hidden_states: Tensor) -> Tensor: + # Input: B x T x C; Output: B x T + hidden_states = self.conv1(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln1(hidden_states)) + hidden_states = self.conv2(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln2(hidden_states)) + return self.proj(hidden_states).squeeze(dim=2) + + +class SeamlessM4THifiGan(nn.Module): + def __init__(self, config: SeamlessM4TConfig): + super().__init__() + model_in_dim = config.unit_embed_dim + config.lang_embed_dim + config.spkr_embed_dim + self.leaky_relu_slope = config.leaky_relu_slope + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + 
+                    config.upsample_initial_channel // (2**i),
+                    config.upsample_initial_channel // (2 ** (i + 1)),
+                    kernel_size=kernel_size,
+                    stride=upsample_rate,
+                    padding=(kernel_size - upsample_rate) // 2,
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.upsampler)):
+            channels = config.upsample_initial_channel // (2 ** (i + 1))
+            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
+                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
+
+        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
+
+    def forward(self, input_embeds: torch.FloatTensor) -> torch.FloatTensor:
+        r"""
+        Converts a sequence of unit, language and speaker embeddings into a speech waveform. Passing a batch of input
+        embeddings returns a batch of speech waveforms. Passing a single, un-batched sequence returns a single,
+        un-batched speech waveform.
+
+        Args:
+            input_embeds (`torch.FloatTensor`):
+                Tensor containing the concatenated unit, language and speaker embeddings. Can be batched and of shape
+                `(batch_size, model_in_dim, sequence_length)`, or un-batched and of shape `(model_in_dim,
+                sequence_length)`. Note that `model_in_dim` is the sum of `config.unit_embed_dim`,
+                `config.lang_embed_dim` and `config.spkr_embed_dim`.
+
+        Returns:
+            `torch.FloatTensor`: Tensor containing the speech waveform. If the input is batched, will be of shape
+            `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
+        """
+
+        hidden_states = self.conv_pre(input_embeds)
+        for i in range(self.num_upsamples):
+            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
+            hidden_states = self.upsampler[i](hidden_states)
+
+            res_state = self.resblocks[i * self.num_kernels](hidden_states)
+            for j in range(1, self.num_kernels):
+                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
+            hidden_states = res_state / self.num_kernels
+
+        hidden_states = nn.functional.leaky_relu(hidden_states)
+        hidden_states = self.conv_post(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+
+        # remove seq-len dim since this collapses to 1
+        waveform = hidden_states.squeeze(1)
+
+        return waveform
+
+
+@add_start_docstrings(
+    """Code HiFi-GAN vocoder as described in this [repository](https://github.com/facebookresearch/speech-resynthesis).""",
+    HIFIGAN_START_DOCSTRING,
+)
+class SeamlessM4TCodeHifiGan(PreTrainedModel):
+    config_class = SeamlessM4TConfig
+    main_input_name = "input_embeds"
+    _no_split_modules = []
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.pad_token_id = config.t2u_pad_token_id
+        self.dur_predictor = SeamlessM4TVariancePredictor(config)
+
+        self.unit_embedding = nn.Embedding(config.unit_hifi_gan_vocab_size, config.unit_embed_dim)
+        self.speaker_embedding = nn.Embedding(config.vocoder_num_spkrs, config.spkr_embed_dim)
+        self.language_embedding = nn.Embedding(config.vocoder_num_langs, config.lang_embed_dim)
+
+        self.hifi_gan = SeamlessM4THifiGan(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _get_dur_output_lengths(self, input_ids, dur_out):
+        """
+        Computes the output length after the duration layer.
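+
+        With no padding, this is simply the total predicted duration. For example, with `dur_out = [[2, 3, 1]]` and
+        no padding in `input_ids`, the output length is `2 + 3 + 1 = 6` frames.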
+ """ + unit_lengths = (input_ids != self.pad_token_id).sum(1) + + # take care of edge cases where no padding or too many padding + unit_lengths = torch.clamp(unit_lengths, 0, dur_out.shape[1] - 1) + + cumulative_dur_out = torch.cumsum(dur_out, dim=1) + unit_lengths = cumulative_dur_out.gather(dim=1, index=unit_lengths.unsqueeze(1)).squeeze() + + return unit_lengths + + def _get_output_hifigan_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the hifigan convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return ( + torch.div(input_length + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor") + 1 + ) + + def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1 + + # conv_pre + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + # upsampler + for i, (upsample_rate, kernel_size) in enumerate( + zip(self.config.upsample_rates, self.config.upsample_kernel_sizes) + ): + input_lengths = _transpose_conv_out_length( + input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2 + ) + + # resblock + for i in range(len(self.config.upsample_rates)): + for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes): + for dil in dilation: + input_lengths = _conv_out_length( + input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil + ) + + for dil in dilation: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1) + + # conv_post + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + return input_lengths + + def forward( + self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor + ) -> Tuple[torch.Tensor]: + """ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTextToUnitForConditionalGeneration`]. [What are input + IDs?](../glossary#input-ids) + spkr_id (`int`, *optional*): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + tgt_lang (`str`, *optional*): + The language id to use as target language for translation. + """ + hidden_states = self.unit_embedding(input_ids).transpose(1, 2) + spkr = self.speaker_embedding(spkr_id).transpose(1, 2) + lang = self.language_embedding(lang_id).transpose(1, 2) + + log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2)) + dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1) + # B x C x T + if hidden_states.size(0) == 1: + hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2) + else: + # if batched sample, need to interleave per sample, and pad -> loss of parallelism + if hidden_states.shape[0] > 1 and self.training: + logger.warning( + """`self.training=True` and you use batching. 
You lose parallelism during the hifigan + forward pass because the samples are interleaved.""" + ) + hidden_states = [ + torch.repeat_interleave(hidden_state, duration, dim=-1).transpose(0, 1) + for (hidden_state, duration) in zip(hidden_states, dur_out) + ] + + hidden_states = nn.utils.rnn.pad_sequence(hidden_states, batch_first=True).transpose(1, 2) + + spkr = spkr.repeat(1, 1, hidden_states.shape[-1]) + lang = lang.repeat(1, 1, hidden_states.shape[-1]) + hidden_states = torch.cat([lang, hidden_states, spkr], dim=1) + + hidden_states = self.hifi_gan(hidden_states) + + unit_lengths = self._get_dur_output_lengths(input_ids, dur_out) + lengths = self._get_output_hifigan_lengths(unit_lengths) + + return hidden_states, lengths + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.apply_weight_norm() + weight_norm(self.hifi_gan.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.hifi_gan.conv_post) + + +############ WHOLE MODEL related code ################ + + +@add_start_docstrings( + "The text-to-text SeamlessM4T Model transformer which can be used for T2TT.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + 
self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + 
cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def generate(
+        self,
+        input_ids=None,
+        tgt_lang=None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=False,
+        **kwargs,
+    ):
+        """
+        Generates sequences of token ids.
+
+
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to `generate()`, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+        Parameters:
+            input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to `generate` matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logits processor is passed that is already created with the arguments or a
+                generation config, an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criterion is passed that is already created with the arguments or a
+                generation config, an error is thrown. This feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step, conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+                Retrieval](https://arxiv.org/abs/2010.00904).
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. The possible
+            [`~utils.ModelOutput`] types are:
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
+        """
+        # prepare text_decoder_input_ids
+        text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
+        # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
+        if tgt_lang is not None:
+            batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))
+
+            if hasattr(self.generation_config, "text_decoder_lang_to_code_id"):
+                # also accept __xxx__
+                tgt_lang = tgt_lang.replace("__", "")
+                if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in
+                        {", ".join(self.generation_config.text_decoder_lang_to_code_id.keys())}"""
+                    )
+                # tgt_lang gets priority over decoder input ids
+                text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
+                text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
+            else:
+                raise ValueError(
+                    """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
+                    the target language to the right token id. Make sure to load the right generation config."""
+                )
+        else:
+            # only a warning, otherwise errors appear in the tests
+            logger.warning(
+                """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get
+                a correct generation, otherwise the generation will probably make no sense."""
+            )
+
+        return super().generate(
+            input_ids,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            decoder_input_ids=text_decoder_input_ids,
+            **kwargs,
+        )
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    "The speech-to-text SeamlessM4T Model transformer which can be used for S2TT.",
+    SEAMLESS_M4T_START_DOCSTRING,
+)
+class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin):
+    _keys_to_ignore_on_load_missing = ["text_decoder", "t2u_model", "vocoder"]
+    main_input_name = "input_features"
+
+    _tied_weights_keys = [
+        "lm_head.weight",
+        "text_decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config: SeamlessM4TConfig):
+        super().__init__(config)
+
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+        self.speech_encoder = SeamlessM4TSpeechEncoder(config)
+        self.text_decoder = SeamlessM4TDecoder(config, self.shared)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.speech_encoder
+
+    def get_decoder(self):
+        return self.text_decoder
+
+    def 
get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        lm_logits = self.lm_head(decoder_outputs[0])
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            labels = labels.to(lm_logits.device)
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            outputs = decoder_outputs + encoder_outputs
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def generate(
+        self,
+        input_features=None,
+        tgt_lang=None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=False,
+        **kwargs,
+    ):
+        """
+        Generates sequences of token ids.
+
+
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to `generate()`, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+        Parameters:
+            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+                Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to `generate` matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logits processor is passed that is already created with the arguments or a
+                generation config, an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criterion is passed that is already created with the arguments or a
+                generation config, an error is thrown. This feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step, conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+                Retrieval](https://arxiv.org/abs/2010.00904).
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. The possible
+            [`~utils.ModelOutput`] types are:
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
+        """
+        text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
+        # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
+        input_features = input_features if input_features is not None else kwargs.pop("inputs")
+        if tgt_lang is not None:
+            inputs = kwargs.get("inputs_embeds") if input_features is None else input_features
+            inputs = (
+                inputs
+                if inputs is not None
+                else kwargs.get("encoder_outputs", {"last_hidden_state": None})["last_hidden_state"]
+            )
+            batch_size = len(inputs)
+
+            if hasattr(self.generation_config, "text_decoder_lang_to_code_id"):
+                # also accept __xxx__
+                tgt_lang = tgt_lang.replace("__", "")
+                if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in
+                        {", ".join(self.generation_config.text_decoder_lang_to_code_id.keys())}"""
+                    )
+                # tgt_lang gets priority over decoder input ids
+                text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang)
+                text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device)
+            else:
+                raise ValueError(
+                    """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps
+                    the target language to the right token id. 
Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + return super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The text-to-speech SeamlessM4T Model transformer which can be used for T2ST.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is 
changed to `False` since `labels` is provided.")
+                use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn
+            logger.warning(
+                "This is the same forward method as `SeamlessM4TForTextToText`. "
+                "It doesn't use the text-to-unit model `SeamlessM4TTextToUnitForConditionalGeneration`. "
+                "If you want to generate speech, use the `.generate` method."
+            )
+            encoder_outputs = self.text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        encoder_attention_mask = attention_mask
+
+        # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.text_decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        lm_logits = self.lm_head(decoder_outputs[0])
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            labels = labels.to(lm_logits.device)
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            outputs = decoder_outputs + encoder_outputs
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        return_intermediate_token_ids: Optional[bool] = None,
+        tgt_lang: Optional[str] = None,
+        spkr_id: Optional[int] = 0,
+        **kwargs,
+    ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]:
+        """
+        Generates translated audio waveforms.
+
+
+
+        This method successively calls the `.generate` function of two different sub-models. 
You can specify keyword
+        arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments
+        that will be passed to one of them.
+
+        For example, calling `.generate(input_ids, num_beams=4, speech_do_sample=True)` will successively perform
+        beam-search decoding on the text model, and multinomial beam-search sampling on the speech model.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            spkr_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be passed to the `generate` method of the
+                    text model and speech model respectively. They have priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+
+        Returns:
+            `Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
+                - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
+                - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
+                sequence_length)` and `waveform_lengths` which gives the length of each sample.
+        """
+        batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))
+
+        if tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+        else:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id. Make sure to load the right generation config."""
+                    )
+                elif tgt_lang not in lang_code_to_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model.
+                        Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. 
Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_ids, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + 
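+        # Shape sketch for the vocoder output above (illustrative, not a contract): `waveform`
+        # is zero-padded to the longest sample in the batch, and `waveform_lengths` holds each
+        # sample's true length, so roughly waveform[i, ..., : waveform_lengths[i]] recovers the
+        # audio of sample i.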
if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The speech-to-speech SeamlessM4T Model transformer which can be used for S2ST.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["text_encoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForSpeechToText`. It doesn't use `self.t2u_model`." + "If you want to generate speech, use the `generate` method." + ) + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. 
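+
+        A minimal end-to-end sketch (the checkpoint name, the `audio` array and the 16 kHz
+        sampling rate are illustrative assumptions, not requirements of this method):
+
+        ```python
+        >>> from transformers import AutoProcessor, SeamlessM4TForSpeechToSpeech
+
+        >>> processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
+        >>> model = SeamlessM4TForSpeechToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium")
+
+        >>> inputs = processor(audios=audio, sampling_rate=16_000, return_tensors="pt")
+        >>> waveform, waveform_lengths = model.generate(**inputs, tgt_lang="fra")
+        ```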
+
+        For example, calling `.generate(input_features, num_beams=4, speech_do_sample=True)` will successively perform
+        beam-search decoding on the text model, and multinomial beam-search sampling on the speech model.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+        Args:
+            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+                Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            spkr_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be passed to the `generate` method of the
+                    text model and speech model respectively. They have priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+
+        Returns:
+            `Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`:
+                - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
+                - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
+                sequence_length)` and `waveform_lengths` which gives the length of each sample.
+        """
+        batch_size = len(input_features) if input_features is not None else len(kwargs.get("inputs_embeds"))
+
+        if tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+        else:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id. Make sure to load the right generation config."""
+                    )
+                elif tgt_lang not in lang_code_to_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model.
+                        Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4T supports
+                        more languages for text translation than for speech synthesis."""
+                    )
+
+        kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs)
+        kwargs_text["output_hidden_states"] = True
+        kwargs_text["return_dict_in_generate"] = True
+        kwargs_text["output_scores"] = True
+
+        text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
+        # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
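+        # Sketch of the forced decoder prefix built below (values illustrative): with
+        # batch_size=2 and a hypothetical "fra" code id of 256057, `text_decoder_input_ids`
+        # becomes tensor([[256057], [256057]]) of shape (batch_size, 1), so decoding starts
+        # from the target-language token.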
+ text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_features, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get last_hidden_state from encoder + encoder_hidden_states = self.speech_encoder(input_features=input_features, attention_mask=attention_mask)[0] + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return 
SeamlessM4TGenerationOutput(
+                waveform=waveform,
+                waveform_lengths=waveform_lengths,
+                sequences=sequences,
+                unit_sequences=output_unit_ids,
+            )
+
+        return waveform, waveform_lengths
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    "The original SeamlessM4T Model transformer which can be used for all tasks available (S2ST, S2TT, T2TT, T2ST).",
+    SEAMLESS_M4T_START_DOCSTRING,
+    """
+        current_modality (`str`, *optional*, defaults to `"text"`):
+            Default modality. Used to initialize the model.
+    """,
+)
+class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = [
+        "lm_head.weight",
+        "text_encoder.embed_tokens.weight",
+        "text_decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config, current_modality="text"):
+        super().__init__(config)
+
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+
+        self.text_encoder = SeamlessM4TEncoder(config, self.shared)
+        self.speech_encoder = SeamlessM4TSpeechEncoder(config)
+        self.text_decoder = SeamlessM4TDecoder(config, self.shared)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        self.current_modality = current_modality
+        if current_modality == "speech":
+            self.main_input_name = "input_features"
+
+        # these models already call post_init in their initialization
+        self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config)
+        self.vocoder = SeamlessM4TCodeHifiGan(config)
+
+    def set_modality(self, modality="text"):
+        if modality == "text":
+            self.main_input_name = "input_ids"
+            self.current_modality = "text"
+        elif modality == "speech":
+            self.main_input_name = "input_features"
+            self.current_modality = "speech"
+        else:
+            raise ValueError(f"`modality={modality}` is not a valid modality. 
It must be `text` or `speech`.") + + def get_encoder(self): + if self.current_modality == "text": + return self.text_encoder + else: + return self.speech_encoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if input_ids is None and input_features is None and inputs_embeds is None and encoder_outputs is None: + raise ValueError( + "`input_ids`,`input_features`, `inputs_embeds` and `encoder_outputs` are all empty. Make sure at least one of them is not." + ) + elif input_features is not None: + if input_ids is not None: + logger.warning( + "`input_ids` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through the `speech_encoder`. " + "Make sure that `input_features` and `input_ids` are mutually exclusive." + ) + + if inputs_embeds is not None: + logger.warning( + "`inputs_embeds` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through `speech_encoder`. " + "`inputs_embeds` will be ignored." + ) + + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." 
+ ) + + self.set_modality("speech") + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + elif input_ids is not None or inputs_embeds is not None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." + ) + self.set_modality("text") + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + # input modality = speech so new attention mask + if self.current_modality == "speech" and attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + generate_speech: Optional[bool] = True, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated token ids and/or 
translated audio waveforms.
+
+
+
+        This method successively calls the `.generate` function of two different sub-models. You can specify keyword
+        arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments
+        that will be passed to one of them.
+
+        For example, calling `.generate(input_ids=input_ids, num_beams=4, speech_do_sample=True)` will successively
+        perform beam-search decoding on the text model, and multinomial beam-search sampling on the speech model.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*):
+                Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio. Note that if `generate_speech=False`, this parameter will be
+                ignored.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            spkr_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            generate_speech (`bool`, *optional*, defaults to `True`):
+                If `False`, will only return the text tokens and won't generate speech.
+
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be passed to the `generate` method of the
+                    text model and speech model respectively. They have priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+        Returns:
+            `Union[SeamlessM4TGenerationOutput, Tuple[Tensor], ModelOutput]`:
+                - If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`].
+                - If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of
+                shape `(batch_size, sequence_length)` and `waveform_lengths` which gives the length of each sample.
+                - If `generate_speech=False`, it will return a `ModelOutput`.
+        """
+        if input_ids is None and input_features is None and kwargs.get("inputs_embeds", None) is None:
+            raise ValueError(
+                "`input_ids`, `input_features` and `inputs_embeds` are all empty. Make sure at least one of them is not."
+ ) + + if generate_speech and tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + + if tgt_lang is not None: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + batch_size = ( + len(input_features) + if input_features is not None + else (len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))) + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + if input_features is not None: + self.set_modality("speech") + if input_ids is not None: + logger.warning( + "`input_features` and `input_ids` are both non empty. `input_features` will be used in priority " + "through the speech encoder. Make sure `input_features=None` if you want to use the text encoder." + ) + text_generation_output = super().generate(input_features=input_features, **kwargs_text) + else: + self.set_modality("text") + text_generation_output = super().generate(input_ids=input_ids, input_features=None, **kwargs_text) + sequences = text_generation_output.sequences + + if not generate_speech: + return text_generation_output + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get encoder last hidden states + if self.current_modality == "speech": + # get last_hidden_state from encoder - must do a pass through the speech encoder + encoder_hidden_states = self.speech_encoder( + input_features=input_features, attention_mask=attention_mask + ).last_hidden_state + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + else: + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) 
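+        # A worked illustration (values are made up): with batch_size=2 and num_return_sequences=3,
+        # sequences_scores holds 6 scores; view(batch_size, -1).argmax(-1) picks the best beam per
+        # batch item, e.g. [1, 0], and adding arange(2) * 3 == [0, 3] turns those row-local indices
+        # into the flat indices [1, 3] used to select one sequence per batch item below.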
+ if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +__all__ = [ + "SeamlessM4TForTextToSpeech", + "SeamlessM4TForSpeechToSpeech", + "SeamlessM4TForTextToText", + "SeamlessM4TForSpeechToText", + "SeamlessM4TModel", + "SeamlessM4TPreTrainedModel", + "SeamlessM4TCodeHifiGan", + "SeamlessM4THifiGan", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TTextToUnitModel", +] diff --git a/docs/transformers/src/transformers/models/seamless_m4t/processing_seamless_m4t.py b/docs/transformers/src/transformers/models/seamless_m4t/processing_seamless_m4t.py new file mode 100644 index 0000000000000000000000000000000000000000..90d96b61c99c1f27f4d9d81189bee43c8be5974a --- /dev/null +++ 
b/docs/transformers/src/transformers/models/seamless_m4t/processing_seamless_m4t.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Audio/Text processor class for SeamlessM4T
+"""
+
+from ...processing_utils import ProcessorMixin
+
+
+class SeamlessM4TProcessor(ProcessorMixin):
+    r"""
+    Constructs a SeamlessM4T processor which wraps a SeamlessM4T feature extractor and a SeamlessM4T tokenizer into a
+    single processor.
+
+    [`SeamlessM4TProcessor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and
+    [`SeamlessM4TTokenizerFast`]. See the [`~SeamlessM4TProcessor.__call__`] and [`~SeamlessM4TProcessor.decode`] for
+    more information.
+
+    Args:
+        feature_extractor ([`SeamlessM4TFeatureExtractor`]):
+            The audio processor is a required input.
+        tokenizer ([`SeamlessM4TTokenizerFast`]):
+            The tokenizer is a required input.
+    """
+
+    feature_extractor_class = "SeamlessM4TFeatureExtractor"
+    tokenizer_class = ("SeamlessM4TTokenizer", "SeamlessM4TTokenizerFast")
+
+    def __init__(self, feature_extractor, tokenizer):
+        super().__init__(feature_extractor, tokenizer)
+
+    def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwargs):
+        """
+        Main method to prepare one or several sequence(s) and audio(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to SeamlessM4TTokenizerFast's [`~SeamlessM4TTokenizerFast.__call__`] if `text` is not
+        `None` to encode the text. To prepare the audio(s), this method forwards the `audios` and `kwargs` arguments to
+        SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audios` is not `None`. Please refer
+        to the docstring of the above two methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The audio or batch of audios to be prepared. Each audio can be a NumPy array or a PyTorch tensor. In
+                case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is the number of
+                channels, and T the sample length of the audio.
+            src_lang (`str`, *optional*):
+                The language code of the input texts/audios. If not specified, the last `src_lang` specified will be
+                used.
+            tgt_lang (`str`, *optional*):
+                The code of the target language. If not specified, the last `tgt_lang` specified will be used.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the
+                tokenizer.
+        Returns:
+            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`.
+        """
+        sampling_rate = kwargs.pop("sampling_rate", None)
+
+        if text is None and audios is None:
+            raise ValueError("You have to specify either text or audios. Both cannot be `None`.")
+        elif text is not None and audios is not None:
+            raise ValueError(
+                "Text and audios are mutually exclusive when passed to `SeamlessM4T`. Specify one or the other."
+            )
+        elif text is not None:
+            if tgt_lang is not None:
+                self.tokenizer.tgt_lang = tgt_lang
+            if src_lang is not None:
+                self.tokenizer.src_lang = src_lang
+            encoding = self.tokenizer(text, **kwargs)
+
+            return encoding
+
+        else:
+            encoding = self.feature_extractor(audios, sampling_rate=sampling_rate, **kwargs)
+            return encoding
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
+        Please refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        feature_extractor_input_names = self.feature_extractor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
+
+
+__all__ = ["SeamlessM4TProcessor"]
diff --git a/docs/transformers/src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py b/docs/transformers/src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d73e6a099a5c6ae5c8923715f124c0264fe2c11
--- /dev/null
+++ b/docs/transformers/src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py
@@ -0,0 +1,567 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
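+# A usage sketch (hedged): the checkpoint name and language codes mirror the class docstring
+# below, and the comments only paraphrase `set_src_lang_special_tokens`/`set_tgt_lang_special_tokens`.
+#
+#     tokenizer = SeamlessM4TTokenizer.from_pretrained("facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra")
+#     model_inputs = tokenizer("Hello")          # source mode: [src_lang_code] tokens [eos]
+#     labels = tokenizer(text_target="Bonjour")  # target mode: [eos, tgt_lang_code] tokens [eos]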
+"""Tokenization classes for SeamlessM4T.""" + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import ( + BatchEncoding, + PreTokenizedInput, + PreTrainedTokenizer, + TextInput, +) +from ...tokenization_utils_base import AddedToken +from ...utils import PaddingStrategy, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +@requires(backends=("sentencepiece",)) +class SeamlessM4TTokenizer(PreTrainedTokenizer): + """ + Construct a SeamlessM4T tokenizer. + + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + The tokenization method is ` ` for source language documents, and ` ` for target language documents. + + Examples: + + ```python + >>> from transformers import SeamlessM4TTokenizer + + >>> tokenizer = SeamlessM4TTokenizer.from_pretrained( + ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + src_lang (`str`, *optional*, defaults to `"eng"`): + The language to use as source language for translation. + tgt_lang (`str`, *optional*, defaults to `"fra"`): + The language to use as target language for translation. + sp_model_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the model initialization. 
+        additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
+            A tuple or a list of additional special tokens. Can be used to specify the list of languages that will be
+            supported by the tokenizer.
+        add_prefix_space (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated just like
+            any other word.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        tokenizer_file=None,
+        src_lang="eng",
+        tgt_lang="fra",
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        additional_special_tokens=None,
+        add_prefix_space=True,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        # Add this unused argument to keep some important Copied from statements
+        self.legacy = False
+        self.vocab_file = vocab_file
+
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
+
+        # Vocab    |    0    |    1    |   2    |    3   |  4   |  5   |  6   |  7   |  8   |  9
+        # -------- | ------- | ------- | ------ | ------ | ---- | ---- | ---- | ---- | ---- | ----
+        # spm      | '<unk>' | '<s>'   | '</s>' | 'an'   | 'en' | '_d' | 'er' | 'in' | '_s' | '_a'
+        # fairseq  | '<pad>' | '<unk>' | '<s>'  | '</s>' | 'an' | 'en' | '▁d' | 'er' | 'in' | '▁s'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self._added_tokens_decoder = {
+            0: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token,
+            1: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token,
+            2: AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token,
+            3: AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token,
+        }
+
+        # The first "real" token "an" has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
+        self.sp_model_size = len(self.sp_model)
+
+        self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang
+        self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang
+        self.add_prefix_space = add_prefix_space
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+        self.set_src_lang_special_tokens(self._src_lang)
+        self.set_tgt_lang_special_tokens(self._tgt_lang)
+
+    # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__getstate__
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__setstate__
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+    @property
+    def vocab_size(self):
+        return len(self.sp_model)
+
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text_pair_target: Optional[
+            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+        ] = None,
+        padding: Union[bool, str, PaddingStrategy] = True,
+        pad_to_multiple_of: Optional[int] = 2,
+        src_lang: Optional[str] = None,
+        tgt_lang: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+                list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+                you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+                list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+                you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            pad_to_multiple_of (`int`, *optional*):
+                If set, will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            src_lang (`str`, *optional*):
+                A string representing the source language. If not specified, the last `src_lang` specified (either
+                during initialization or when calling this tokenizer) will be used.
+            tgt_lang (`str`, *optional*):
+                A string representing the target language. If not specified, the last `tgt_lang` specified (either
+                during initialization or when calling this tokenizer) will be used.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizer.__call__`].
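+
+        Example (a minimal sketch; the checkpoint name matches the class docstring above and outputs
+        are not shown):
+
+        ```python
+        >>> from transformers import SeamlessM4TTokenizer
+
+        >>> tokenizer = SeamlessM4TTokenizer.from_pretrained("facebook/hf-seamless-m4t-medium")
+        >>> # `src_lang`/`tgt_lang` passed here persist as the tokenizer's current languages
+        >>> inputs = tokenizer("Hello world", src_lang="eng", tgt_lang="fra", return_tensors="pt")
+        ```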
+ """ + if src_lang is not None: + self.src_lang = src_lang + if tgt_lang is not None: + self.tgt_lang = tgt_lang + + output = super().__call__( + text=text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + padding=padding, + pad_to_multiple_of=pad_to_multiple_of, + **kwargs, + ) + + return BatchEncoding(output, tensor_type=kwargs.get("return_tensors")) + + @property + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + if "__" not in new_src_lang: + self._src_lang = f"__{new_src_lang}__" + else: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang: str) -> None: + if "__" not in new_tgt_lang: + self._tgt_lang = f"__{new_tgt_lang}__" + else: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(self._tgt_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model.") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + if "__" not in tgt_lang: + tgt_lang = f"__{tgt_lang}__" + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = { + self.convert_ids_to_tokens(i): i for i in range(self.fairseq_offset, self.vocab_size + self.fairseq_offset) + } + vocab.update(self.added_tokens_encoder) + return vocab + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. 
+ """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.prepare_seq2seq_batch with eng_Latn->eng, fra_Latn->fra + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "eng", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "fra", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + # Copied from 
+    def _switch_to_input_mode(self):
+        return self.set_src_lang_special_tokens(self.src_lang)
+
+    # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_target_mode
+    def _switch_to_target_mode(self):
+        return self.set_tgt_lang_special_tokens(self.tgt_lang)
+
+    def set_src_lang_special_tokens(self, src_lang) -> None:
+        """Reset the special tokens to the source lang setting.
+        Prefix=[src_lang_code], suffix = [eos]
+        """
+        self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
+        self.init_kwargs["src_lang"] = src_lang
+
+        if self.cur_lang_code == self.unk_token_id:
+            logger.warning_once(
+                f"`src_lang={src_lang}` has not been found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
+            )
+
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]
+
+    # https://github.com/facebookresearch/fairseq2/blob/c53f18e6be6b8b46b722f2249b8397b7eccd7ad3/src/fairseq2/models/nllb/tokenizer.py#L112-L116
+    def set_tgt_lang_special_tokens(self, lang: str) -> None:
+        """Reset the special tokens to the target lang setting.
+        Prefix=[eos, tgt_lang_code] and suffix=[eos].
+        """
+        self.cur_lang_code = self.convert_tokens_to_ids(lang)
+        self.init_kwargs["tgt_lang"] = lang
+
+        if self.cur_lang_code == self.unk_token_id:
+            logger.warning_once(
+                f"`tgt_lang={lang}` has not been found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
+            )
+
+        self.prefix_tokens = [self.eos_token_id, self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]
+
+
+__all__ = ["SeamlessM4TTokenizer"]
diff --git a/docs/transformers/src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py b/docs/transformers/src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1142c10719fb1f6b934eb70b2b641da3140ccef
--- /dev/null
+++ b/docs/transformers/src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py
@@ -0,0 +1,450 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
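+# NOTE: this "fast" tokenizer is backed by the Rust `tokenizers` library. When only the
+# sentencepiece model is available, it is converted from the slow `SeamlessM4TTokenizer`
+# (hence the `slow_tokenizer_class` attribute below), and the language prefix/suffix tokens
+# are applied through the backend's `TemplateProcessing` post-processor.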
+"""Fast Tokenization class for SeamlessM4T.""" + +import os +from shutil import copyfile +from typing import List, Optional, Tuple, Union + +from tokenizers import processors + +from ...tokenization_utils import ( + BatchEncoding, + PreTokenizedInput, + TextInput, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import PaddingStrategy, is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_seamless_m4t import SeamlessM4TTokenizer +else: + SeamlessM4TTokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + + +class SeamlessM4TTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" SeamlessM4T tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + The tokenization method is ` ` for source language documents, and ` ` for target language documents. + + Examples: + + ```python + >>> from transformers import SeamlessM4TTokenizerFast + + >>> tokenizer = SeamlessM4TTokenizerFast.from_pretrained( + ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + src_lang (`str`, *optional*, defaults to `"eng"`): + The language to use as source language for translation. 
+        tgt_lang (`str`, *optional*, defaults to `"fra"`):
+            The language to use as target language for translation.
+        additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
+            A tuple or a list of additional special tokens.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = SeamlessM4TTokenizer
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        src_lang="eng",
+        tgt_lang="fra",
+        additional_special_tokens=None,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+        self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang
+        self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+        self.set_tgt_lang_special_tokens(self._tgt_lang)
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    @property
+    # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        if "__" not in new_src_lang:
+            self._src_lang = f"__{new_src_lang}__"
+        else:
+            self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
+    @property
+    def tgt_lang(self) -> str:
+        return self._tgt_lang
+
+    @tgt_lang.setter
+    def tgt_lang(self, new_tgt_lang: str) -> None:
+        if "__" not in new_tgt_lang:
+            self._tgt_lang = f"__{new_tgt_lang}__"
+        else:
+            self._tgt_lang = new_tgt_lang
+        self.set_tgt_lang_special_tokens(self._tgt_lang)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. The special tokens depend on calling set_lang.
+
+        A SeamlessM4T sequence has the following format, where `X` represents the sequence:
+
+        - `input_ids` (for encoder) `[src_lang_code] X [eos]`
+        - `decoder_input_ids`: (for decoder) `[eos, tgt_lang_code] X [eos]`
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
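+
+        Example (illustrative ids only; note that pairs are concatenated without a separator):
+
+        ```python
+        >>> # with prefix_tokens == [src_code_id] and suffix_tokens == [eos_id]:
+        >>> # build_inputs_with_special_tokens([10], [20]) == [src_code_id, 10, 20, eos_id]
+        ```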
+ """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + if "__" not in tgt_lang: + tgt_lang = f"__{tgt_lang}__" + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.prepare_seq2seq_batch with "fra_Latn"->"fra", "eng_Latn"->"eng" + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "eng", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "fra", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_input_mode + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_target_mode + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`tgt_lang={src_lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." 
+
+        self.init_kwargs["src_lang"] = src_lang
+
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]
+
+        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+            pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+            special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+        )
+
+    def set_tgt_lang_special_tokens(self, lang: str) -> None:
+        """Reset the special tokens to the target lang setting.
+        Prefix=[eos, tgt_lang_code] and suffix=[eos].
+        """
+        self.cur_lang_code = self.convert_tokens_to_ids(lang)
+
+        if self.cur_lang_code == self.unk_token_id:
+            logger.warning_once(
+                f"`tgt_lang={lang}` has not been found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id."
+            )
+
+        self.init_kwargs["tgt_lang"] = lang
+
+        self.prefix_tokens = [self.eos_token_id, self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]
+
+        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+            pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+            special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+        )
+
+    # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+ ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory.") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @classmethod + def _from_pretrained( + cls, + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=None, + cache_dir=None, + local_files_only=False, + _commit_hash=None, + _is_local=False, + **kwargs, + ): + tokenizer = super()._from_pretrained( + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=_commit_hash, + _is_local=_is_local, + **kwargs, + ) + + # ensure also set after from pretrained + tokenizer.set_src_lang_special_tokens(tokenizer._src_lang) + tokenizer.set_tgt_lang_special_tokens(tokenizer._tgt_lang) + + return tokenizer + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + **kwargs, + ): + """ + Args: + text (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            pad_to_multiple_of (`int`, *optional*):
+                If set, will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            src_lang (`str`, *optional*):
+                A string representing the source language. If not specified, the last `src_lang` specified (either
+                during initialization or when calling this tokenizer) will be used.
+            tgt_lang (`str`, *optional*):
+                A string representing the target language. If not specified, the last `tgt_lang` specified (either
+                during initialization or when calling this tokenizer) will be used.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizerFast.__call__`].
+        """
+        if src_lang is not None:
+            self.src_lang = src_lang
+        if tgt_lang is not None:
+            self.tgt_lang = tgt_lang
+
+        output = super().__call__(
+            text=text,
+            text_pair=text_pair,
+            text_target=text_target,
+            text_pair_target=text_pair_target,
+            padding=padding,
+            pad_to_multiple_of=pad_to_multiple_of,
+            **kwargs,
+        )
+
+        return output
+
+
+__all__ = ["SeamlessM4TTokenizerFast"]
diff --git a/docs/transformers/src/transformers/models/seamless_m4t_v2/__init__.py b/docs/transformers/src/transformers/models/seamless_m4t_v2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f256f8575595d6b3409c030e1d6c3b33c75aa9
--- /dev/null
+++ b/docs/transformers/src/transformers/models/seamless_m4t_v2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
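+# The submodules are imported eagerly under `TYPE_CHECKING` so that static type checkers can
+# see the public symbols, while at runtime the module is replaced by a `_LazyModule` that
+# resolves attributes on first access.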
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_seamless_m4t_v2 import *
+    from .modeling_seamless_m4t_v2 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py b/docs/transformers/src/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29eaf4540100365a82b1a23290932ee9f428ffc
--- /dev/null
+++ b/docs/transformers/src/transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py
@@ -0,0 +1,425 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SeamlessM4Tv2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SeamlessM4Tv2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`~SeamlessM4Tv2Model`]. It is used to instantiate
+    a SeamlessM4Tv2 model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the SeamlessM4Tv2
+    [facebook/seamless-m4t-v2-large](https://huggingface.co/facebook/seamless-m4t-v2-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 256102):
+            Vocabulary size of the text modality of the SeamlessM4Tv2 model. Defines the number of different tokens
+            that can be represented by the `inputs_ids` passed when calling [`~SeamlessM4Tv2Model`],
+            [`~SeamlessM4Tv2ForTextToSpeech`] or [`~SeamlessM4Tv2ForTextToText`].
+        t2u_vocab_size (`int`, *optional*, defaults to 10082):
+            Unit vocabulary size of the SeamlessM4Tv2 model. Defines the number of different "unit tokens" that can be
+            represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4Tv2Model`],
+            [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
+        char_vocab_size (`int`, *optional*, defaults to 10943):
+            Character vocabulary size of the SeamlessM4Tv2 model. Defines the number of different character tokens that
+            can be represented by the `char_inputs_ids` passed when calling the Text-To-Units sub-model of
+            [`~SeamlessM4Tv2Model`], [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`].
+
+        > Parameters shared across sub-models
+
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the "intermediate" layers in the architecture.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model's text encoder and decoder might ever be used with. Typically
+            set this to something large just in case (e.g., 512 or 1024 or 2048).
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether the model is used as an encoder/decoder or not.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.05):
+            The LayerDrop probability for the encoders. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.05):
+            The LayerDrop probability for the decoders. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the decoder and feed-forward layers. If string,
+            `"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all attention layers.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all activation layers in the model.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Scale embeddings by dividing by sqrt(d_model).
+
+        > Text encoder and text decoder specific parameters
+
+        encoder_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer text encoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 8192):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer text encoder.
+        decoder_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer text decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 8192):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer text decoder.
+        decoder_start_token_id (`int`, *optional*, defaults to 3):
+            If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only
+            applied in the text decoder.
+        max_new_tokens (`int`, *optional*, defaults to 256):
+            The maximum number of text tokens to generate, ignoring the number of tokens in the prompt.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the _padding_ text token. Only applied to the text-decoder model.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model.
+        eos_token_id (`int`, *optional*, defaults to 3):
+            The id of the _end-of-stream_ text token.
Only applied to the text-decoder model. + + > Speech encoder specific parameters + + speech_encoder_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer speech encoder. + speech_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer speech encoder. + speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder. + speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`): + The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`, + `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported. + speech_encoder_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all layers in the speech encoder. + add_adapter (`bool`, *optional*, defaults to `True`): + Add an adapter layer on top of the speech encoder. + speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability for the speech encoder. See the [LayerDrop paper](see + https://arxiv.org/abs/1909.11556) for more details. + feature_projection_input_dim (`int`, *optional*, defaults to 160): + Input dimension of the input feature projection of the speech encoder, i.e the dimension after processing + input audios with [`SeamlessM4TFeatureExtractor`]. + adaptor_kernel_size (`int`, *optional*, defaults to 8): + Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adaptor_stride (`int`, *optional*, defaults to 8): + Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adaptor_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all layers in the speech adapter. + num_adapter_layers (`int`, *optional*, defaults to 1): + Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is + True`. + position_embeddings_type (`str`, *optional*, defaults to `"relative_key"`): + Can be specified to `relative_key`. If left to `None`, no relative position embedding is applied. Only + applied to the speech encoder. For more information on `"relative_key"`, please refer to [Self-Attention + with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + conv_depthwise_kernel_size (`int`, *optional*, defaults to 31): + Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder. + left_max_position_embeddings (`int`, *optional*, defaults to 64): + The left clipping value for relative positions. + right_max_position_embeddings (`int`, *optional*, defaults to 8): + The right clipping value for relative positions. + speech_encoder_chunk_size (`int`, *optional*, defaults to 20000): The size of each attention chunk. + speech_encoder_left_chunk_num (`int`, *optional*, defaults to 128): + Number of chunks on the left up to which lookahead is allowed. + + > Text-To-Unit (t2u) model specific parameters + + t2u_bos_token_id (`int`, *optional*, defaults to 0): + The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_pad_token_id (`int`, *optional*, defaults to 1): + The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_eos_token_id (`int`, *optional*, defaults to 2): + The id of the _end-of-stream_ unit token. 
Only applied to the text-to-unit seq2seq model. + t2u_encoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer text-to-unit encoder. + t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder. + t2u_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text-to-unit encoder. + t2u_decoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer text-to-unit decoder. + t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder. + t2u_decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text-to-unit decoder. + t2u_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model text-to-unit component might ever be used with. Typically set + this to something large just in case (e.g., 512 or 1024 or 2048). + t2u_variance_predictor_embed_dim (`int`, *optional*, defaults to 1024): + The projection dimension of the text-to-unit's duration predictor. + t2u_variance_predictor_hidden_dim (`int`, *optional*, defaults to 256): + Internal dimension of the text-to-unit's duration predictor. + t2u_variance_predictor_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the convolutional layers of the text-to-unit's duration predictor. + t2u_variance_pred_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability of the text-to-unit's duration predictor. + + > Hifi-Gan Vocoder specific parameters + + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the output audio will be generated, expressed in hertz (Hz). + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only. + upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`): + A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network. + The length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. Applies to the vocoder only. + upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling + network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match + the length of *upsample_rates*. Applies to the vocoder only. + resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive + field fusion (MRF) module. Applies to the vocoder only. + resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in + the multi-receptive field fusion (MRF) module. Applies to the vocoder only. 
+ leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder + only. + unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000): + Vocabulary size of the SeamlessM4Tv2 vocoder. Defines the number of different unit tokens that can be + represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4Tv2Model`], + [`~SeamlessM4Tv2ForSpeechToSpeech`] or [`~SeamlessM4Tv2ForTextToSpeech`]. + unit_embed_dim (`int`, *optional*, defaults to 1280): + The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only. + lang_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only. + spkr_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only. + vocoder_num_langs (`int`, *optional*, defaults to 36): + Number of langs supported by the vocoder. Might be different from `t2u_num_langs`. + vocoder_num_spkrs (`int`, *optional*, defaults to 200): + Number of speakers supported by the vocoder. + variance_predictor_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the duration predictor. Applies to the vocoder only. + var_pred_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability of the duration predictor. Applies to the vocoder only. + vocoder_offset (`int`, *optional*, defaults to 4): + Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only. + + ```python + >>> from transformers import SeamlessM4Tv2Model, SeamlessM4Tv2Config + + >>> # Initializing a SeamlessM4Tv2 "" style configuration + >>> configuration = SeamlessM4Tv2Config() + + >>> # Initializing a model from the "" style configuration + >>> model = SeamlessM4Tv2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "seamless_m4t_v2" + + def __init__( + self, + vocab_size=256102, + t2u_vocab_size=10082, + char_vocab_size=10943, + # shared config + hidden_size=1024, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + max_position_embeddings=4096, + is_encoder_decoder=True, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + activation_function="relu", + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + scale_embedding=True, + # text encoder|decoder + encoder_layers=24, + encoder_ffn_dim=8192, + encoder_attention_heads=16, + decoder_layers=24, + decoder_ffn_dim=8192, + decoder_attention_heads=16, + decoder_start_token_id=3, + max_new_tokens=256, + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, + # speech_encoder + speech_encoder_layers=24, + speech_encoder_attention_heads=16, + speech_encoder_intermediate_size=4096, + speech_encoder_hidden_act="swish", + speech_encoder_dropout=0.0, + add_adapter=True, + speech_encoder_layerdrop=0.1, + feature_projection_input_dim=160, + adaptor_kernel_size=8, + adaptor_stride=8, + adaptor_dropout=0.1, + num_adapter_layers=1, + position_embeddings_type="relative_key", + conv_depthwise_kernel_size=31, + left_max_position_embeddings=64, + right_max_position_embeddings=8, + speech_encoder_chunk_size=20000, + speech_encoder_left_chunk_num=128, + # t2u config + t2u_bos_token_id=0, + t2u_pad_token_id=1, + t2u_eos_token_id=2, + t2u_encoder_layers=6, + t2u_encoder_ffn_dim=8192, + 
t2u_encoder_attention_heads=16, + t2u_decoder_layers=6, + t2u_decoder_ffn_dim=8192, + t2u_decoder_attention_heads=16, + t2u_max_position_embeddings=4096, + t2u_variance_predictor_embed_dim=1024, + t2u_variance_predictor_hidden_dim=256, + t2u_variance_predictor_kernel_size=3, + t2u_variance_pred_dropout=0.5, + # hifi-gan vocoder config + sampling_rate=16000, + upsample_initial_channel=512, + upsample_rates=[5, 4, 4, 2, 2], + upsample_kernel_sizes=[11, 8, 8, 4, 4], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_slope=0.1, + # specific to Code Hifi-Gan + unit_hifi_gan_vocab_size=10000, + unit_embed_dim=1280, + lang_embed_dim=256, + spkr_embed_dim=256, + vocoder_num_langs=36, + vocoder_num_spkrs=200, + variance_predictor_kernel_size=3, + var_pred_dropout=0.5, + vocoder_offset=4, + **kwargs, + ): + # overall_config + self.vocab_size = vocab_size + self.t2u_vocab_size = t2u_vocab_size + self.char_vocab_size = char_vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.max_new_tokens = max_new_tokens + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.activation_function = activation_function + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.scale_embedding = scale_embedding + # for proper config init + self.num_attention_heads = decoder_attention_heads + self.num_hidden_layers = decoder_layers + + # text|unit encoder|decoder + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_attention_heads = decoder_attention_heads + + # speech_encoder + self.speech_encoder_layers = speech_encoder_layers + self.speech_encoder_hidden_act = speech_encoder_hidden_act + self.speech_encoder_dropout = speech_encoder_dropout + self.speech_encoder_attention_heads = speech_encoder_attention_heads + self.speech_encoder_layerdrop = speech_encoder_layerdrop + self.speech_encoder_intermediate_size = speech_encoder_intermediate_size + self.feature_projection_input_dim = feature_projection_input_dim + self.adaptor_kernel_size = adaptor_kernel_size + self.adaptor_stride = adaptor_stride + self.adaptor_dropout = adaptor_dropout + self.num_adapter_layers = num_adapter_layers + self.position_embeddings_type = position_embeddings_type + self.conv_depthwise_kernel_size = conv_depthwise_kernel_size + self.add_adapter = add_adapter + self.left_max_position_embeddings = left_max_position_embeddings + self.right_max_position_embeddings = right_max_position_embeddings + self.speech_encoder_chunk_size = speech_encoder_chunk_size + self.speech_encoder_left_chunk_num = speech_encoder_left_chunk_num + + # t2u config + self.t2u_bos_token_id = t2u_bos_token_id + self.t2u_pad_token_id = t2u_pad_token_id + self.t2u_eos_token_id = t2u_eos_token_id + self.t2u_encoder_layers = t2u_encoder_layers + self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim + self.t2u_encoder_attention_heads = t2u_encoder_attention_heads + self.t2u_decoder_layers = t2u_decoder_layers + self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim + self.t2u_decoder_attention_heads = t2u_decoder_attention_heads + self.t2u_max_position_embeddings = t2u_max_position_embeddings 
+ self.t2u_variance_predictor_embed_dim = t2u_variance_predictor_embed_dim # TODO: add to docstrings + self.t2u_variance_predictor_hidden_dim = t2u_variance_predictor_hidden_dim # TODO: add to docstrings + self.t2u_variance_predictor_kernel_size = t2u_variance_predictor_kernel_size # TODO: add to docstrings + self.t2u_variance_pred_dropout = t2u_variance_pred_dropout # TODO: add to docstrings + + # hifi-gan vocoder config + # original parameters specific to Hifi-Gan + self.sampling_rate = sampling_rate + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.leaky_relu_slope = leaky_relu_slope + + # specific to Code Hifi-Gan + self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size + self.unit_embed_dim = unit_embed_dim + self.lang_embed_dim = lang_embed_dim + self.spkr_embed_dim = spkr_embed_dim + self.vocoder_num_langs = vocoder_num_langs + self.vocoder_num_spkrs = vocoder_num_spkrs + self.variance_predictor_kernel_size = variance_predictor_kernel_size + self.var_pred_dropout = var_pred_dropout + self.vocoder_offset = vocoder_offset + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + is_encoder_decoder=is_encoder_decoder, + max_position_embeddings=max_position_embeddings, + **kwargs, + ) + + +__all__ = ["SeamlessM4Tv2Config"] diff --git a/docs/transformers/src/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py b/docs/transformers/src/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..c75b7c8139d3eab6e05d17d21ee88f6f7ab84747 --- /dev/null +++ b/docs/transformers/src/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
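As a quick illustration of how the defaults above compose, the sketch below instantiates the configuration and derives the vocoder's total upsampling factor from `upsample_rates`; all values follow from the defaults defined in this class.

```python
from transformers import SeamlessM4Tv2Config

config = SeamlessM4Tv2Config()  # defaults as listed above

# The code HiFi-GAN upsamples each unit frame by the product of `upsample_rates`:
# 5 * 4 * 4 * 2 * 2 = 320 audio samples per unit frame at `sampling_rate` Hz.
factor = 1
for rate in config.upsample_rates:
    factor *= rate
print(factor, config.sampling_rate)  # 320 16000

# Any field can be overridden by keyword, e.g. a smaller text decoder:
small = SeamlessM4Tv2Config(decoder_layers=12, decoder_ffn_dim=4096)
print(small.num_hidden_layers)  # 12 (mirrored from `decoder_layers` in __init__)
```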
+"""Converting Meta SeamlessM4Tv2 checkpoints from seamless_communication to HF.""" + +import argparse +import os +from pathlib import Path + +import torch +from accelerate.utils.modeling import find_tied_parameters +from seamless_communication.inference import Translator + +from transformers import ( + SeamlessM4TFeatureExtractor, + SeamlessM4TProcessor, + SeamlessM4TTokenizer, + SeamlessM4Tv2Config, + SeamlessM4Tv2Model, +) +from transformers.utils import logging + + +# fmt: off +UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ] +# fmt: on + +# fmt: off +VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",] +# fmt: on + +# fmt: off +LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",] +# fmt: on + + +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +vocoder_convert_list = [ + ("ups", "hifi_gan.upsampler"), + ("conv_pre", "hifi_gan.conv_pre"), + ("resblocks", "hifi_gan.resblocks"), + ("conv_post", "hifi_gan.conv_post"), + ("lang", "language_embedding"), + ("spkr", "speaker_embedding"), + ("dict.", "unit_embedding."), + ("dur_predictor.conv1.0", "dur_predictor.conv1"), + ("dur_predictor.conv2.0", "dur_predictor.conv2"), +] + +# order is important +wav2vec_convert_list = [ + ("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"), + ("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"), + ("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"), + ("speech_encoder.inner.layers", "encoder.layers"), + ("speech_encoder.inner_layer_norm", "encoder.layer_norm"), + 
("speech_encoder.adaptor_layers", "adapter.layers"), + ("inner_proj", "intermediate_dense"), + ("self_attn.output_proj", "self_attn.linear_out"), + ("output_proj", "output_dense"), + ("self_attn.k_proj", "self_attn.linear_k"), + ("self_attn.v_proj", "self_attn.linear_v"), + ("self_attn.q_proj", "self_attn.linear_q"), + ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"), + ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"), + ("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"), + ("self_attn.sdpa.r_proj", "self_attn.linear_pos"), + ("conv.pointwise_conv1", "conv_module.pointwise_conv1"), + ("conv.pointwise_conv2", "conv_module.pointwise_conv2"), + ("conv.depthwise_conv", "conv_module.depthwise_conv"), + ("conv.batch_norm", "conv_module.batch_norm"), + ("conv.layer_norm", "conv_module.depthwise_layer_norm"), + ("conv_layer_norm", "conv_module.layer_norm"), + ("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"), + ("speech_encoder.proj2", "intermediate_ffn.output_dense"), + ("speech_encoder.layer_norm", "inner_layer_norm"), +] + +t2u_convert_list = [ + ("t2u_model.final_proj", "lm_head"), + ("t2u_model.", "model."), + ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"), + ("encoder_decoder_attn", "cross_attention"), + ("linear_k", "k_proj"), + ("linear_v", "v_proj"), + ("linear_q", "q_proj"), + ("ffn.inner_proj", "ffn.fc1"), + ("ffn.output_proj", "ffn.fc2"), + ("output_proj", "out_proj"), + ("decoder_frontend.embed_char", "decoder.embed_char"), + ("decoder_frontend.pos_emb_alpha_char", "decoder.pos_emb_alpha_char"), + ("decoder_frontend.embed", "decoder.embed_tokens"), + ("decoder_frontend.pos_emb_alpha", "decoder.pos_emb_alpha"), + ("conv1d.conv", "conv"), + ("conv1d_layer_norm", "conv_layer_norm"), + ("decoder_frontend.variance_adaptor", "decoder"), + ("duration_predictor.conv1.0", "duration_predictor.conv1"), + ("duration_predictor.conv2.0", "duration_predictor.conv2"), +] + +text_convert_list = [ + ("text_encoder.", ""), + ("text_decoder.", ""), + ("text_encoder_frontend.embed", "embed_tokens"), + ("text_decoder_frontend.embed", "embed_tokens"), + ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"), + ("encoder_decoder_attn", "cross_attention"), + ("linear_k", "k_proj"), + ("linear_v", "v_proj"), + ("linear_q", "q_proj"), + ("ffn.inner_proj", "ffn.fc1"), + ("ffn.output_proj", "ffn.fc2"), + ("output_proj", "out_proj"), + ("final_proj", "lm_head"), +] + +CUR_PATH = os.path.dirname(os.path.abspath(__file__)) +default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") +CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub") + + +def _load_hf_config(): + return SeamlessM4Tv2Config() + + +def _convert_model( + original_model, + hf_model, + convert_list, + device, + unwanted_prefix="model.", + filter_state_dict="speech", + exclude_state_dict=None, +): + state_dict = original_model.state_dict() + + # filter func + if isinstance(filter_state_dict, str): + + def filter_func(x): + return filter_state_dict in x[0] + + else: + + def filter_func(item): + if exclude_state_dict is not None and exclude_state_dict in item[0]: + return False + for filter_el in filter_state_dict: + if filter_el in item[0]: + return True + + return False + + state_dict = dict(filter(filter_func, state_dict.items())) + + for k, v in list(state_dict.items()): + new_k = k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + 
+ # must do it by hand + if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric(): + new_k = new_k.replace("layer_norm", "final_layer_norm") + + state_dict[new_k] = state_dict.pop(k) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + extra_keys = set(extra_keys) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k}) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=False) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model + + +def load_model(save_dir, model_type, repo_id): + """ + Meta SeamlessM4Tv2 is made of 8 main components: + - speech_encoder (#1) and speech_encoder_frontend (#2) + - t2u_model (#3) + - text_encoder (#4) and text_encoder_frontend (#5) + - text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend] + - final_proj (#7) + - vocoder (#8) + """ + device = _grab_best_device() + name = "seamlessM4T_v2_large" + + original_model = Translator(name, "vocoder_v2", device, dtype=torch.float32) + + ######### TOKENIZER + + langs = LARGE_SUPPORTED_LANGUAGES + langs = [f"__{lang}__" for lang in langs] + vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model") + + save_dir = os.path.join(save_dir, name) + Path(save_dir).mkdir(exist_ok=True) + + tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs) + + sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__") + + tokenizer.save_pretrained(save_dir) + tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir) + + if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"): + raise ValueError( + f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}" + ) + + ####### get language to ids dict + text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs} + # offset: vocoder unit vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages) + t2u_lang_code_to_id = { + code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES) + for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES) + } + vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)} + + ######### FE + + fe = SeamlessM4TFeatureExtractor(language_code=langs) + + fe.save_pretrained(save_dir) + fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir) + + processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer) + processor.save_pretrained(save_dir) + processor.push_to_hub(repo_id=repo_id, create_pr=True) + + processor = SeamlessM4TProcessor.from_pretrained(save_dir) + + ######## Model + + # init config + hf_config = _load_hf_config() + + ######## get id_to_text and char_to_id from original model tokenizers + id_to_text = {i: original_model.text_tokenizer.model.index_to_token(i) for i in range(hf_config.vocab_size)} + char_to_id = { + original_model.model.t2u_model.decoder_frontend.char_tokenizer.model.index_to_token(i): i for i in range(10904) + } + + # init model + hf_model = SeamlessM4Tv2Model(hf_config) + + 
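+    # The mappings attached below are not model weights; they live on the generation config
+    # so that generation can translate a `tgt_lang` code into the matching token ids for the
+    # text decoder, the text-to-unit model and the vocoder respectively.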
hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id) + hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id) + hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id) + hf_model.generation_config.__setattr__("id_to_text", id_to_text) + hf_model.generation_config.__setattr__("char_to_id", char_to_id) + + # -1. take care of vocoder + # similarly to speech T5 must apply and remove weight norm + hf_model.vocoder.apply_weight_norm() + hf_model.vocoder = _convert_model( + original_model, + hf_model.vocoder, + vocoder_convert_list, + device, + unwanted_prefix="vocoder.code_generator.", + filter_state_dict="vocoder", + ) + hf_model.vocoder.remove_weight_norm() + + # 1. take care of speech encoder + wav2vec = hf_model.speech_encoder + hf_model.speech_encoder = _convert_model( + original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech" + ) + + # 2. take care of t2u + + hf_model.t2u_model = _convert_model( + original_model, + hf_model.t2u_model, + t2u_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict="t2u_model", + ) + + # 3. take care of text encoder + hf_model.text_encoder = _convert_model( + original_model, + hf_model.text_encoder, + text_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict=["model.text_encoder"], + exclude_state_dict="t2u_model", + ) + + # 4. take care of text decoder + hf_model.text_decoder = _convert_model( + original_model, + hf_model.text_decoder, + text_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict=["model.text_decoder"], + exclude_state_dict="t2u_model", + ) + + # 5. take care of final proj + hf_model.lm_head = _convert_model( + original_model, + hf_model.lm_head, + [("final_proj.", "")], + device, + unwanted_prefix="model.", + filter_state_dict=["model.final_proj"], + exclude_state_dict="t2u_model", + ) + + # sanity check + print(find_tied_parameters(hf_model)) + + count_1 = param_count(hf_model) + count_2 = param_count(original_model) + + print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}") + print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}") + + del original_model + + hf_model.generation_config._from_model_config = False + hf_model.save_pretrained(save_dir) + hf_model.push_to_hub(repo_id=repo_id, create_pr=True) + hf_model = SeamlessM4Tv2Model.from_pretrained(save_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + + parser.add_argument( + "--model_type", + default="large", + type=str, + help="Model type.", + ) + + parser.add_argument( + "--save_dir", + default="/home/ubuntu/weights_v2", + type=str, + help="Path to the output PyTorch model.", + ) + + parser.add_argument( + "--repo_id", + default="facebook/seamless-m4t-v2-large", + type=str, + help="Repo ID.", + ) + + args = parser.parse_args() + + load_model(args.save_dir, args.model_type, args.repo_id) diff --git a/docs/transformers/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/docs/transformers/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..998008bbf56ba98ffd19b82b9ab669f98590f87f --- /dev/null +++ b/docs/transformers/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -0,0 +1,4722 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. 
All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SeamlessM4Tv2 model."""
+
+import copy
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Wav2Vec2BaseModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+from .configuration_seamless_m4t_v2 import SeamlessM4Tv2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = ""
+_CONFIG_FOR_DOC = "SeamlessM4Tv2Config"
+
+
+@dataclass
+# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TGenerationOutput with SeamlessM4T->SeamlessM4Tv2
+class SeamlessM4Tv2GenerationOutput(ModelOutput):
+    """
+    Class defining the generated outputs from [`SeamlessM4Tv2Model`], [`SeamlessM4Tv2ForTextToText`],
+    [`SeamlessM4Tv2ForTextToSpeech`], [`SeamlessM4Tv2ForSpeechToSpeech`] and [`SeamlessM4Tv2ForSpeechToText`].
+
+    Args:
+        waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            The final audio waveform predicted by the model.
+        waveform_lengths (`torch.IntTensor` of shape `(batch_size,)`, *optional*):
+            The length in samples of each element in the `waveform` batch.
+        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            The generated translated sequences. This is the output of the text-to-text or the speech-to-text models.
+            The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished
+            early due to the `eos_token_id`.
+        unit_sequences (`torch.LongTensor` of shape `(batch_size, unit_sequence_length)`, *optional*):
+            The generated translated unit sequences. This is the output of the text-to-units model. The second
+            dimension (unit_sequence_length) is either equal to `t2u_max_length` or shorter if all batches finished
+            early due to the `t2u_eos_token_id`.
+    """
+
+    waveform: Optional[torch.FloatTensor] = None
+    waveform_lengths: Optional[torch.IntTensor] = None
+    sequences: Optional[Tuple[torch.FloatTensor]] = None
+    unit_sequences: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class SeamlessM4Tv2TextToUnitDecoderOutput(ModelOutput):
+    """
+    Class defining the outputs from [`SeamlessM4Tv2TextToUnitDecoder`].
+ + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked* + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + padding_mask: Optional[torch.Tensor] = None + + +@dataclass +class SeamlessM4Tv2TextToUnitOutput(ModelOutput): + """ + Class defining the outputs from [`SeamlessM4Tv2TextToUnitForConditionalGeneration`] and + [`SeamlessM4Tv2TextToUnitModel`]. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked* + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    padding_mask: Optional[torch.Tensor] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    loss: Optional[torch.FloatTensor] = None
+
+
+SEAMLESS_M4T_V2_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`~SeamlessM4Tv2Config`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SEAMLESS_M4T_V2_MULTIMODAL_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+            Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+            [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+    """
+
+M4T_TEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+    """
+
+M4T_SPEECH_INPUTS_DOCSTRING = r"""
+    Args:
+        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+            Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+            [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+ """ + +SEAMLESS_M4T_V2_END_INPUTS_DOCSTRING = r""" + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape`(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +M4T_MODEL_INPUTS_DOCSTRING = SEAMLESS_M4T_V2_MULTIMODAL_INPUTS_DOCSTRING + SEAMLESS_M4T_V2_END_INPUTS_DOCSTRING + +M4T_TEXT_INPUTS_DOCSTRING = M4T_TEXT_INPUTS_DOCSTRING + SEAMLESS_M4T_V2_END_INPUTS_DOCSTRING + +M4T_SPEECH_INPUTS_DOCSTRING = M4T_SPEECH_INPUTS_DOCSTRING + SEAMLESS_M4T_V2_END_INPUTS_DOCSTRING + +M4T_TEXT_TO_UNITS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + char_input_ids (`torch.LongTensor` of shape `(batch_size, char_sequence_length)`): + Character indices. The correspondence between characters and indices can be found in `char_to_id`, a + dictionary in the generation configuration. + char_count_per_id (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Number of characters per input id. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+############ UTILS ################
+
+
+# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids (`torch.Tensor`):
+            The ids of the input sequence.
+        padding_idx (`int`):
+            The id of the padding token.
+        past_key_values_length (`int`, *optional*, defaults to 0):
+            Offset added to the computed positions when past key values are present.
+
+    Returns:
+        `torch.Tensor`: The position ids.
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    return incremental_indices.long() + padding_idx
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
+    """
+    Computes an attention mask of the form `(batch, seq_len)` where, for each element in the batch, attention stops
+    at the corresponding length given in `seq_lens`.
+
+    Args:
+        hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`):
+            The sequences to mask, where `*` is any number of sequence-specific dimensions including none.
+        seq_lens (`torch.Tensor` of shape `(batch,)`):
+            Each element represents the length of the sequence at the same index in `hidden_states`.
+
+    Returns:
+        `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)`.
+    """
+    batch_size, mask_seq_len = hidden_states.shape[:2]
+
+    indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1)
+
+    bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len)
+
+    mask = hidden_states.new_ones((batch_size, mask_seq_len))
+
+    mask = mask.masked_fill(bool_mask, 0)
+
+    return mask
+
+
+# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.format_speech_generation_kwargs with SeamlessM4T->SeamlessM4Tv2
+def format_speech_generation_kwargs(kwargs):
+    """
+    Format kwargs for SeamlessM4Tv2 models that generate speech, attributing kwargs to either the text generation or
+    the speech generation models.
+
+    Args:
+        kwargs (`dict`):
+            Keyword arguments are of two types:
+
+            - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+              except for `decoder_input_ids` which will only be passed through the text components.
+            - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the
+              text model and speech model respectively. They take priority over the keywords without a prefix.
+
+              This means you can, for example, specify a generation strategy for one generation but not for the
+              other.
+    """
+    # attribute kwargs to models
+    kwargs_text = {}
+    kwargs_speech = {}
+    for key, value in kwargs.items():
+        if key.startswith("text_"):
+            key = key[len("text_") :]
+            kwargs_text[key] = value
+        elif key.startswith("speech_"):
+            key = key[len("speech_") :]
+            kwargs_speech[key] = value
+        elif key == "generation_config":
+            kwargs_text[key] = value
+        else:
+            # If the key is already in a specific config, then it's been set with a
+            # submodule-specific value and we don't override
+            if key not in kwargs_text:
+                kwargs_text[key] = value
+            if key not in kwargs_speech:
+                kwargs_speech[key] = value
+    return kwargs_text, kwargs_speech
+
+
+############ SPEECH ENCODER related code ################
+
+
+class SeamlessM4Tv2ConformerFeatureProjection(nn.Module):
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TConformerFeatureProjection.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps)
+        self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size)
+        self.dropout = nn.Dropout(config.speech_encoder_dropout)
+
+    def forward(self, hidden_states):
+        # non-projected hidden states are needed for quantization
+        norm_hidden_states = self.layer_norm(hidden_states.to(self.layer_norm.weight.dtype))
+        hidden_states = self.projection(norm_hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TConformerFeedForward with SeamlessM4T->SeamlessM4Tv2
+class SeamlessM4Tv2ConformerFeedForward(nn.Module):
+    def __init__(self, config, act_fn=None, dropout=None):
+        super().__init__()
+        dropout = dropout if dropout is not None else config.speech_encoder_dropout
+        act_fn = act_fn if act_fn is not None else config.speech_encoder_hidden_act
+
+        self.intermediate_dropout = nn.Dropout(dropout)
+        self.intermediate_dense = nn.Linear(config.hidden_size, config.speech_encoder_intermediate_size)
+        self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn
+
+ self.output_dense = nn.Linear(config.speech_encoder_intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class SeamlessM4Tv2ConformerConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block. Uses a causal depthwise convolution similar to the one
+ described in Section 2.1 of `https://doi.org/10.48550/arxiv.1609.03499`."""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+ raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.pointwise_conv1 = nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = nn.GLU(dim=1)
+ self.depthwise_conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=0,
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.depthwise_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.activation = ACT2FN[config.speech_encoder_hidden_act]
+ self.pointwise_conv2 = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = nn.Dropout(config.speech_encoder_dropout)
+
+ def forward(self, hidden_states, attention_mask=None):
+ hidden_states = self.layer_norm(hidden_states)
+
+ # Ensure that we do not leak padded positions in depthwise convolution.
+ # Put 0 where necessary
+ if attention_mask is not None:
+ hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
+
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, seq_len)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, seq_len)
+ hidden_states = self.glu(hidden_states)
+
+ # Pad the sequence entirely on the left because of causal convolution.
+ hidden_states = torch.nn.functional.pad(hidden_states, (self.depthwise_conv.kernel_size[0] - 1, 0))
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.depthwise_layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class SeamlessM4Tv2ConformerSelfAttention(nn.Module):
+ """Construct a SeamlessM4Tv2ConformerSelfAttention object.
+ Can be enhanced with relative position embeddings.
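+
+ When `position_embeddings_type == "relative_key"`, the raw attention scores are augmented with a learned
+ embedding of the clamped query-key distance, following the relative position scheme of Shaw et al. (2018),
+ https://arxiv.org/abs/1803.02155.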
+ """ + + def __init__(self, config, use_position_embeddings=True): + super().__init__() + + self.head_size = config.hidden_size // config.speech_encoder_attention_heads + self.num_heads = config.speech_encoder_attention_heads + self.position_embeddings_type = config.position_embeddings_type if use_position_embeddings else None + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.speech_encoder_dropout) + + if self.position_embeddings_type == "relative_key": + self.left_max_position_embeddings = config.left_max_position_embeddings + self.right_max_position_embeddings = config.right_max_position_embeddings + num_positions = self.left_max_position_embeddings + self.right_max_position_embeddings + 1 + self.distance_embedding = nn.Embedding(num_positions, self.head_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + if self.position_embeddings_type == "relative_key": + query_length, key_length = query.shape[2], key.shape[2] + + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_r - position_ids_l + distance = torch.clamp(distance, -self.left_max_position_embeddings, self.right_max_position_embeddings) + + positional_embedding = self.distance_embedding(distance + self.left_max_position_embeddings) + positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility + + relative_position_attn_weights = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) + attn_weights = attn_weights + (relative_position_attn_weights / math.sqrt(self.head_size)) + + # apply attention_mask if necessary + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # => (batch, head, time1, time2) + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_weights = self.dropout(attn_weights) + + # => (batch, head, time1, d_k) + attn_output = torch.matmul(attn_weights, value) + + # => (batch, time1, hidden_size) + attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + attn_output = self.linear_out(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class SeamlessM4Tv2ConformerEncoderLayer(nn.Module): + 
"""Conformer block based on https://arxiv.org/abs/2005.08100.""" + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4Tv2, attention_dropout->speech_encoder_dropout, torch.nn->nn + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.speech_encoder_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = SeamlessM4Tv2ConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = SeamlessM4Tv2ConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = SeamlessM4Tv2ConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = SeamlessM4Tv2ConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + conv_attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = residual + hidden_states + + # 4. Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weights + + +class SeamlessM4Tv2ConformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.dropout = nn.Dropout(config.speech_encoder_dropout) + self.layers = nn.ModuleList( + [SeamlessM4Tv2ConformerEncoderLayer(config) for _ in range(config.speech_encoder_layers)] + ) + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + + def _apply_chunk_attention(self, attention_mask, hidden_states): + """ + Creates a chunk attention mask. It creates a mask to prevent attention across chunks, ensuring that each + position attends only to positions within its own chunk. If a left chunk overlap is specified + (`speech_encoder_chunk_size` in the configuration), the attention mask is adjusted accordingly to allow each + position to also attends the `speech_encoder_chunk_size - 1` previous chunks. 
+ """ + sequence_len = hidden_states.shape[1] + + chunk_indices = torch.arange(sequence_len, device=hidden_states.device) + chunk_indices = torch.div(chunk_indices, self.config.speech_encoder_chunk_size).long() + + start_indices = torch.full_like(chunk_indices, 0) + if self.config.speech_encoder_left_chunk_num >= 0: + start_indices = (chunk_indices - self.config.speech_encoder_left_chunk_num).clamp_(min=0) + start_indices = start_indices * self.config.speech_encoder_chunk_size + start_indices = start_indices + start_indices = start_indices.unsqueeze(1).expand(-1, sequence_len) + + end_indices = ((chunk_indices + 1) * self.config.speech_encoder_chunk_size).clamp_(max=sequence_len) + + end_indices = end_indices.unsqueeze(1).expand(-1, sequence_len) + + indices = torch.arange(sequence_len, device=hidden_states.device).unsqueeze(0).expand(sequence_len, -1) + + chunk_mask = (indices < start_indices) | (indices >= end_indices) + chunk_mask = chunk_mask.unsqueeze(0).unsqueeze(0) + + attention_mask = chunk_mask if attention_mask is None else (attention_mask.bool() | chunk_mask) + attention_mask = attention_mask.to(dtype=hidden_states.dtype) + return attention_mask + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + if self.config.speech_encoder_chunk_size is not None: + attention_mask = self._apply_chunk_attention(attention_mask, hidden_states) + + if attention_mask is not None: + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = ( + True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False + ) + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return 
BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TConformerAdapterLayer with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2ConformerAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.adaptor_dropout + + self.kernel_size = config.adaptor_kernel_size + self.stride = config.adaptor_stride + + # 1. residual convolution + self.residual_layer_norm = nn.LayerNorm(embed_dim) + self.residual_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.activation = nn.GLU(dim=1) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.self_attn = SeamlessM4Tv2ConformerSelfAttention(config, use_position_embeddings=False) + self.self_attn_dropout = nn.Dropout(dropout) + + # Feed-forward + self.ffn_layer_norm = nn.LayerNorm(embed_dim) + self.ffn = SeamlessM4Tv2ConformerFeedForward(config, act_fn="relu", dropout=dropout) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + pad = self.kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 + + return seq_lens.floor() + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + residual = self.residual_layer_norm(hidden_states) + + # Apply pooling to the residual to match the sequence length of the + # multi-head attention output. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + residual = residual.transpose(1, 2) + residual = self.residual_conv(residual) + residual = self.activation(residual) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + residual = residual.transpose(1, 2) + + hidden_states = self.self_attn_layer_norm(hidden_states) + # Apply pooling before feeding to the multihead-attention layer. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.self_attn_conv(hidden_states) + hidden_states = self.activation(hidden_states) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + hidden_states = hidden_states.transpose(1, 2) + + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + hidden_states.device + ) + attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) + attention_mask = _prepare_4d_attention_mask( + attention_mask, + hidden_states.dtype, + ) + + # The rest of the computation is identical to a vanilla Transformer + # encoder layer. 
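+ # Note that both branches above were downsampled by `self.stride` along the time axis
+ # (the residual through `residual_conv`, the hidden states through `self_attn_conv`),
+ # so their lengths match for the residual connection after attention.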
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ residual = hidden_states
+
+ hidden_states = self.ffn_layer_norm(hidden_states)
+ hidden_states = self.ffn(hidden_states) + residual
+
+ return hidden_states
+
+
+# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TConformerAdapter with SeamlessM4T->SeamlessM4Tv2
+class SeamlessM4Tv2ConformerAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.layers = nn.ModuleList(
+ SeamlessM4Tv2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)
+ )
+
+ def forward(self, hidden_states, attention_mask):
+ # down project hidden_states if necessary
+
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, attention_mask)
+
+ return hidden_states
+
+
+############ TEXT / UNITS related code ################
+
+
+# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4Tv2
+class SeamlessM4Tv2ScaledWordEmbedding(nn.Embedding):
+ """
+ This module overrides nn.Embedding's forward by multiplying the embeddings by `embed_scale`.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale
+
+
+# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
+class SeamlessM4Tv2SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+ super().__init__()
+ self.offset = 2
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
+
+ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
+ if hasattr(self, "weights"):
+ # in forward put the weights on the correct dtype and device of the param
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+ self.register_buffer("weights", emb_weights, persistent=False)
+
+ @staticmethod
+ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+ """
+ Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
+ "Attention Is All You Need".
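+
+ Concretely, the first `embedding_dim // 2` channels hold `sin(pos / 10000^(i / (half_dim - 1)))` and the
+ remaining channels hold the matching cosines, i.e. the sine and cosine halves are concatenated rather than
+ interleaved.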
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values_length: int = 0, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +class SeamlessM4Tv2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.bart.modeling_bart.BartAttention.__init__ with Bart->SeamlessM4Tv2 + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[SeamlessM4Tv2Config] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ is_cross_attention = encoder_hidden_states is not None
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # use encoder_hidden_states if cross attention
+ current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ # checking that the `sequence_length` of the `past_key_value` is the same as the provided
+ # `encoder_hidden_states` to support prefix tuning
+ if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ else:
+ key_states = self._shape(self.k_proj(current_states))
+ value_states = self._shape(self.v_proj(current_states))
+ if past_key_value is not None and not is_cross_attention:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ query_states = self._shape(self.q_proj(hidden_states) * self.scaling)
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ context_states = torch.matmul(attn_weights, value_states)
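+ # (batch_size, n_heads, seq_length, head_dim) -> (batch_size, seq_length, n_heads * head_dim)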
+ context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) + attn_output = self.out_proj(context_states) + + if output_attentions: + return attn_output, attn_weights, past_key_value + else: + return attn_output, None, past_key_value + + +# Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4Tv2,DenseActDense->FeedForwardNetwork, d_model->hidden_size +class SeamlessM4Tv2FeedForwardNetwork(nn.Module): + def __init__(self, config: SeamlessM4Tv2Config, ffn_dim: int): + super().__init__() + self.fc1 = nn.Linear(config.hidden_size, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, config.hidden_size) + self.dropout = nn.Dropout(config.activation_dropout) + self.act = ACT2FN[config.activation_function] + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.fc2.weight, torch.Tensor) + and hidden_states.dtype != self.fc2.weight.dtype + and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8) + ): + hidden_states = hidden_states.to(self.fc2.weight.dtype) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TEncoderLayer with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2EncoderLayer(nn.Module): + def __init__(self, config: SeamlessM4Tv2Config, encoder_ffn_dim=None, encoder_attention_heads=None): + super().__init__() + encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim + encoder_attention_heads = ( + config.encoder_attention_heads if encoder_attention_heads is None else encoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4Tv2Attention( + embed_dim=self.embed_dim, + num_heads=encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4Tv2FeedForwardNetwork(config, ffn_dim=encoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. 
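+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.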
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2DecoderLayer(nn.Module): + def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): + super().__init__() + decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim + decoder_attention_heads = ( + config.decoder_attention_heads if decoder_attention_heads is None else decoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4Tv2Attention( + embed_dim=self.embed_dim, + num_heads=decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = SeamlessM4Tv2Attention( + self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4Tv2FeedForwardNetwork(config, ffn_dim=decoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + past_key_value (`Tuple(torch.FloatTensor)`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=cross_attn_past_key_value, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value += cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class SeamlessM4Tv2TextToUnitDecoderLayer(nn.Module): + def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): + super().__init__() + decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim + decoder_attention_heads = ( + config.decoder_attention_heads if decoder_attention_heads is None else decoder_attention_heads + ) + self.dropout = config.dropout + self.embed_dim = config.hidden_size + + self.self_attn = SeamlessM4Tv2Attention( + embed_dim=self.embed_dim, + num_heads=decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.conv1 = nn.Conv1d(self.embed_dim, self.embed_dim, kernel_size=7, stride=1, padding="same") + self.activation_fn = ACT2FN[config.activation_function] + self.conv2 = nn.Conv1d(self.embed_dim, self.embed_dim, kernel_size=7, stride=1, padding="same") + + self.conv_layer_norm = nn.LayerNorm(config.hidden_size) + self.conv_dropout = nn.Dropout(self.dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. 
+ padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked*
+ or 0 for *masked*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Conv
+ residual = hidden_states
+
+ # Apply padding mask to avoid leaking padded positions in the convolution layer
+ if padding_mask is not None:
+ hidden_states = hidden_states.masked_fill(~padding_mask.bool().unsqueeze(-1), 0.0)
+ hidden_states = self.conv1(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if padding_mask is not None:
+ hidden_states = hidden_states.masked_fill(~padding_mask.bool().unsqueeze(-1), 0.0)
+
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.conv2(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states = self.conv_dropout(hidden_states)
+ hidden_states = residual + hidden_states
+ hidden_states = self.conv_layer_norm(hidden_states)
+
+ outputs = (hidden_states, present_key_value)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+############ SUB-MODELS related code ################
+
+
+class SeamlessM4Tv2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """ + + config_class = SeamlessM4Tv2Config + base_model_prefix = "seamless_m4t_v2" + supports_gradient_checkpointing = True + _no_split_modules = [ + "SeamlessM4Tv2EncoderLayer", + "SeamlessM4Tv2DecoderLayer", + "SeamlessM4Tv2ConformerEncoderLayer", + "SeamlessM4Tv2TextToUnitDecoderLayer", + ] + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, SeamlessM4Tv2ConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, SeamlessM4Tv2ConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TPreTrainedModel._compute_sub_sample_lengths_from_attention_mask + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride + pad = kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1 + + return seq_lens.floor() + + def _indices_to_subwords(self, input_ids): + """ + Returns the corresponding text string for each input id. + """ + if not hasattr(self.generation_config, "id_to_text"): + raise ValueError( + """This model generation config doesn't have a `id_to_text` key which maps + token ids to subwords. Make sure to load the right generation config.""" + ) + batch_size, sequence_len = input_ids.shape + + subwords_batch = [] + for batch_id in range(batch_size): + subwords = [] + for i in range(sequence_len): + subword = self.generation_config.id_to_text.get(str(input_ids[batch_id, i].item())) + subwords.append(str(subword)) + subwords_batch.append(subwords) + return subwords_batch + + def _count_character_length_in_subword( + self, + input_ids, + subwords_batch, + merge_space_with_prev_subword=False, + pad_token_id=0, + unk_token_id=1, + space="▁", + ): + """ + Counts the number of characters per text string associated with the input token id. + + Args: + input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + subwords_batch (`List[List[str]]` of shape `(batch_size, sequence_length)`): + Corresponding text string for each input id. + merge_space_with_prev_subword (`bool`, *optional*, defaults to `False`): + Indicates if the space character is merged with the previous subword. If `False`, it will be merged + with the next subword. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the _padding_ text token. 
+ sample, the lengths of subsequent subwords will be set to 0.
+ unk_token_id (`int`, *optional*, defaults to 1):
+ The id of the _unknown_ text token. Associated with a subword of length 1.
+ space (`str`, *optional*, defaults to `"▁"`):
+ The space character.
+ """
+ batch_size, _ = input_ids.shape
+
+ char_count_per_id = input_ids.new_zeros(input_ids.size())
+
+ subword_lens = input_ids.ne(pad_token_id).sum(1)
+
+ for batch_id in range(batch_size):
+ # We slice out the tensor till the padding index.
+ subword_indices = input_ids[batch_id, : subword_lens[batch_id]]
+ subwords = subwords_batch[batch_id][: subword_lens[batch_id]]
+
+ is_next_start_with_space = [
+ len(subwords[i + 1]) > 1 and subwords[i + 1][0] == space if i < len(subwords) - 1 else False
+ for i in range(len(subwords))
+ ]
+ is_punc = [
+ len(subwords[i]) == 1
+ and not subwords[i].isalpha()
+ and not subwords[i].isnumeric()
+ and subwords[i] != space
+ for i in range(len(subwords))
+ ]
+ for i, (subword_idx, subword) in enumerate(zip(subword_indices, subwords)):
+ if subword_idx == pad_token_id:
+ break
+
+ if subword_idx == unk_token_id:
+ # We set char_len to 1 for an unk token.
+ char_len = 1
+
+ if merge_space_with_prev_subword and is_next_start_with_space[i]:
+ char_len += 1
+ else:
+ # By default, spaces are merged with the next subword.
+ # char_len includes the space.
+ char_len = len(subword)
+
+ if merge_space_with_prev_subword:
+ # Add the space for the next subword.
+ if is_next_start_with_space[i]:
+ char_len += 1
+ # Subtract the space for the current subword.
+ if i > 0 and is_next_start_with_space[i - 1]:
+ char_len -= 1
+ else:
+ # Merge space with punctuation mark by default.
+ if is_punc[i] and is_next_start_with_space[i]:
+ char_len += 1
+ # Subtract the space for the subword succeeding the punctuation mark.
+ elif i > 0 and is_punc[i - 1] and is_next_start_with_space[i - 1]:
+ char_len -= 1
+
+ char_count_per_id[batch_id, i] = char_len
+
+ return char_count_per_id
+
+ def _get_char_input_ids(self, input_ids, subwords_batch, char_count_per_id, pad_token_id=0, unk_token_id=1):
+ """
+ Returns the corresponding character input id for each character of `subwords_batch`.
+
+ Args:
+ input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+ subwords_batch (`List[List[str]]` of shape `(batch_size, sequence_length)`):
+ Corresponding text string for each input id.
+ char_count_per_id (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ Number of characters per input id.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the _padding_ text token. If it is encountered when calculating the length of a subword
+ sample, the lengths of subsequent subwords will be set to 0.
+ unk_token_id (`int`, *optional*, defaults to 1):
+ The id of the _unknown_ text token. Associated with a subword of length 1.
+ Returns:
+ `torch.Tensor`: Tensor of shape `(batch_size, char_sequence_length)` containing the id of each character.
+ """
+ if not hasattr(self.generation_config, "char_to_id"):
+ raise ValueError(
+ """This model generation config doesn't have a `char_to_id` key which maps
+ characters to character ids. 
Make sure to load the right generation config.""" + ) + + batch_size = input_ids.shape[0] + max_len = int(char_count_per_id.sum(1).max().item()) + + char_seqs = input_ids.new_zeros((batch_size, max_len)).fill_(pad_token_id) + + subword_lens = input_ids.ne(pad_token_id).sum(1) + + for batch_id in range(batch_size): + total = 0 + subword_indices = input_ids[batch_id, : subword_lens[batch_id]] + subwords = subwords_batch[batch_id][: subword_lens[batch_id]] + for subword_idx, subword in zip(subword_indices, subwords): + if subword_idx == unk_token_id: + char_ids = [unk_token_id] + else: + # Get char token indices corresponding to the subwords. + char_ids = [self.generation_config.char_to_id.get(ch, unk_token_id) for ch in list(subword)] + char_seq_len = len(char_ids) + char_seqs[batch_id, total : total + char_seq_len] = torch.tensor(char_ids).to(char_seqs) + total += char_seq_len + return char_seqs + + def _hard_upsample(self, hidden_states, durations): + """ + Repeats the time dimension of each sample in the batch based on the corresponding duration. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, *)`, *optional*): + The sequence to repeat, where `*` is any number of sequence-specific dimensions including none. + durations (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates how many times to repeat time segments. + """ + if hidden_states.size(0) == 1: + hidden_states = torch.repeat_interleave(hidden_states, durations.view(-1), dim=1) + else: + # if batched sample, need to interleave per sample, and pad -> loss of parallelism + if hidden_states.shape[0] > 1 and self.training: + logger.warning_once( + """`self.training=True` and you use batching. You lose parallelism during the hifigan + forward pass because the samples are interleaved.""" + ) + hidden_states = [ + torch.repeat_interleave(hidden_state, duration, dim=0) + for (hidden_state, duration) in zip(hidden_states, durations) + ] + + hidden_states = nn.utils.rnn.pad_sequence(hidden_states, batch_first=True) + + return hidden_states + + +@add_start_docstrings( + """Transformer speech encoder consisting of *config.speech_encoder_layers* conformer self attention layers. 
+ Each layer is a [`SeamlessM4Tv2ConformerEncoderLayer`].""", + SEAMLESS_M4T_V2_START_DOCSTRING, +) +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TSpeechEncoder with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2PreTrainedModel): + main_input_name = "input_features" + + def __init__(self, config: SeamlessM4Tv2Config): + super().__init__(config) + + self.feature_projection = SeamlessM4Tv2ConformerFeatureProjection(config) + self.encoder = SeamlessM4Tv2ConformerEncoder(config) + self.intermediate_ffn = SeamlessM4Tv2ConformerFeedForward(config, act_fn="relu", dropout=0.0) + self.adapter = SeamlessM4Tv2ConformerAdapter(config) if config.add_adapter else None + self.inner_layer_norm = nn.LayerNorm(config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_features is None: + raise ValueError( + """Both `input_features` and `inputs_embeds` are `None` in `SeamlessM4Tv2SpeechEncoder.forward`. + Make sure one of them is not `None`.""" + ) + + hidden_states = self.feature_projection(input_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + expanded_hidden_states = self.intermediate_ffn(hidden_states) + hidden_states = hidden_states + 0.5 * expanded_hidden_states + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) + + hidden_states = self.inner_layer_norm(hidden_states) + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# inspired from MBart and NllbMoe +@add_start_docstrings( + "Transformer encoder consisting of *config.encoder_layers* self attention layers. 
Each layer is a [`SeamlessM4Tv2EncoderLayer`].", + SEAMLESS_M4T_V2_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + is_t2u_encoder (`bool`, *optional*, defaults to `False`): + indicates if it belongs to the text-to-units model, in which case it won't have input embeddings + """, +) +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TEncoder with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel): + def __init__( + self, + config: SeamlessM4Tv2Config, + embed_tokens: Optional[nn.Embedding] = None, + is_t2u_encoder: bool = False, + ): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + embed_dim = config.hidden_size + + self.is_t2u_encoder = is_t2u_encoder + self.max_source_positions = config.max_position_embeddings + + if not self.is_t2u_encoder: + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + + layers = [] + for _ in range(config.encoder_layers): + layers.append( + SeamlessM4Tv2EncoderLayer( + config, + encoder_attention_heads=config.encoder_attention_heads, + encoder_ffn_dim=config.encoder_ffn_dim, + ) + ) + + self.layers = nn.ModuleList(layers) + + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and self.is_t2u_encoder: + raise ValueError( + "You cannot pass input_ids to the encoder of the text_to_units model. Pass inputs_embeds instead." + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.is_t2u_encoder: + embed_pos = self.embed_positions(input) + + hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device) + else: + hidden_states = inputs_embeds + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.forward, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +@add_start_docstrings( + "Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`SeamlessM4Tv2DecoderLayer`].", + SEAMLESS_M4T_V2_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + """, +) +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoder with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel): + def __init__( + self, + config: SeamlessM4Tv2Config, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + # if embed_tokens defined, use its shape instead + self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding( + embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_tokens.weight = embed_tokens.weight + else: + self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding( + self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + layers = [] + for _ in range(config.decoder_layers): + layers.append( + SeamlessM4Tv2DecoderLayer( + config, + decoder_attention_heads=config.decoder_attention_heads, + decoder_ffn_dim=config.decoder_ffn_dim, + ) + ) + self.layers = nn.ModuleList(layers) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. 
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
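+
+        Example (an illustrative sketch only; this module is normally driven by the wrapping model, and
+        `config` and `encoder_hidden_states` are assumed to already exist):
+
+        ```python
+        >>> import torch
+        >>> decoder = SeamlessM4Tv2Decoder(config)
+        >>> decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])
+        >>> outputs = decoder(input_ids=decoder_input_ids, encoder_hidden_states=encoder_hidden_states)
+        >>> hidden = outputs.last_hidden_state  # (batch_size, target_length, hidden_size)
+        ```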
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[3],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`SeamlessM4Tv2DecoderLayer`].", + SEAMLESS_M4T_V2_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + """, +) +class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel): + def __init__( + self, + config: SeamlessM4Tv2Config, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + # if embed_tokens defined, use its shape instead + self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx) + self.embed_tokens.weight = embed_tokens.weight + else: + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + + self.embed_char = nn.Embedding(config.char_vocab_size, config.hidden_size) + self.embed_char_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + self.pos_emb_alpha_char = nn.Parameter(torch.ones(1)) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + self.duration_predictor = SeamlessM4Tv2VariancePredictor( + config.variance_predictor_embed_dim, + config.variance_predictor_hidden_dim, + config.variance_predictor_kernel_size, + config.variance_pred_dropout, + ) + + self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + layers = [] + for _ in range(config.decoder_layers): + layers.append( + SeamlessM4Tv2TextToUnitDecoderLayer( + config, + decoder_attention_heads=config.decoder_attention_heads, + decoder_ffn_dim=config.decoder_ffn_dim, + ) + ) + self.layers = nn.ModuleList(layers) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + char_input_ids: Optional[torch.LongTensor] = None, + char_count_per_id: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SeamlessM4Tv2TextToUnitDecoderOutput]: + r""" + Args: + char_input_ids (`torch.LongTensor` of shape `(batch_size, char_sequence_length)`): + Character indices. The correspondence between characters and indices can be found in `char_to_id`, a + dictionary in the generation configuration. + char_count_per_id (`torch.Tensor` of shape `(batch_size, encoder_sequence_length)`): + Number of characters per text input id. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # create padding mask for character lengths + char_padding_mask = _compute_new_attention_mask(char_input_ids, char_count_per_id.sum(1)) + + # upsample hidden states according to characters sequence lengths + char_hidden_states = self._hard_upsample(encoder_hidden_states, char_count_per_id) + # embed char positions + char_positions = self.pos_emb_alpha_char * self.embed_char_positions(inputs_embeds=char_hidden_states) + # update char hidden states with positions and char embeddings + char_hidden_states = self.embed_char(char_input_ids) * self.embed_scale + char_positions + char_hidden_states + + # predict duration + log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask) + dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1) + dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0) + + # upsample char hidden states according to predicted duration + char_hidden_states = self._hard_upsample(char_hidden_states, dur_out) + + positions = self.pos_emb_alpha * self.embed_positions(inputs_embeds=char_hidden_states) + hidden_states = char_hidden_states + positions + + padding_mask = _compute_new_attention_mask(hidden_states, dur_out.sum(1)) + attention_mask = _prepare_4d_attention_mask(padding_mask, hidden_states.dtype) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + padding_mask, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + padding_mask=padding_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns, padding_mask] if v is not None) + return SeamlessM4Tv2TextToUnitDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + padding_mask=padding_mask, + ) + + +@add_start_docstrings( + "Transformer bare text-to-unit encoder-decoder. 
The encoder is a [`SeamlessM4Tv2Encoder`] without embeddings and the decoder is a [`SeamlessM4Tv2TextToUnitDecoder`].", + SEAMLESS_M4T_V2_START_DOCSTRING, + """ + embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. + """, +) +class SeamlessM4Tv2TextToUnitModel(SeamlessM4Tv2PreTrainedModel): + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitModel.__init__ with SeamlessM4T->SeamlessM4Tv2, Decoder->TextToUnitDecoder + def __init__( + self, + config: SeamlessM4Tv2Config, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + super().__init__(config) + + self.encoder = SeamlessM4Tv2Encoder(config, is_t2u_encoder=True) + self.decoder = SeamlessM4Tv2TextToUnitDecoder(config, embed_tokens_decoder) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + char_input_ids: Optional[torch.LongTensor] = None, + char_count_per_id: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn, padding_mask) + decoder_outputs = self.decoder( + char_input_ids=char_input_ids, + char_count_per_id=char_count_per_id, + encoder_hidden_states=encoder_outputs[0], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return SeamlessM4Tv2TextToUnitOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + padding_mask=decoder_outputs.padding_mask, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "Transformer text-to-unit encoder-decoder with a language model head. The base encoder-decoder model is a [`SeamlessM4Tv2TextToUnitModel`].", + SEAMLESS_M4T_V2_START_DOCSTRING, + """ + embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. 
+ """, +) +class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = [ + "vocoder", + "speech_encoder", + "text_encoder", + "text_decoder", + ] + _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"] + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.__init__ with SeamlessM4T->SeamlessM4Tv2 + def __init__( + self, + config: SeamlessM4Tv2Config, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + # update config - used principaly for bos_token_id etc. + config = copy.deepcopy(config) + for param, val in config.to_dict().items(): + if param.startswith("t2u_"): + config.__setattr__(param[4:], val) + super().__init__(config) + + self.model = SeamlessM4Tv2TextToUnitModel(config, embed_tokens_decoder) + + self.lm_head = nn.Linear(config.hidden_size, config.t2u_vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.get_encoder + def get_encoder(self): + return self.model.encoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.get_decoder + def get_decoder(self): + return self.model.decoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.get_input_embeddings + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.set_input_embeddings + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + @add_start_docstrings_to_model_forward(M4T_TEXT_TO_UNITS_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + char_input_ids: Optional[torch.LongTensor] = None, + char_count_per_id: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + char_input_ids=char_input_ids, + char_count_per_id=char_count_per_id, + attention_mask=attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = 
loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return SeamlessM4Tv2TextToUnitOutput( + last_hidden_state=lm_logits, + padding_mask=outputs.padding_mask, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + loss=masked_lm_loss, + ) + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration._tie_weights + def _tie_weights(self) -> None: + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + +############ VOCODER related code ################ + + +HIFIGAN_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SeamlessM4Tv2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class SeamlessM4Tv2VariancePredictor(nn.Module): + def __init__(self, embed_dim, hidden_dim, kernel_size, var_pred_dropout): + super().__init__() + + self.conv1 = nn.Conv1d( + embed_dim, + hidden_dim, + kernel_size=kernel_size, + padding="same", + ) + self.activation_fuction = nn.ReLU() + self.ln1 = nn.LayerNorm(hidden_dim) + self.dropout_module = nn.Dropout(p=var_pred_dropout) + self.conv2 = nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_size=kernel_size, + padding="same", + ) + self.ln2 = nn.LayerNorm(hidden_dim) + self.proj = nn.Linear(hidden_dim, 1) + + def forward(self, hidden_states: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor: + # Input: B x T x C; Output: B x T + if padding_mask is not None: + hidden_states = hidden_states.masked_fill(~padding_mask.bool().unsqueeze(-1), 0.0) + hidden_states = self.conv1(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln1(hidden_states)) + if padding_mask is not None: + hidden_states = hidden_states.masked_fill(~padding_mask.bool().unsqueeze(-1), 0.0) + hidden_states = self.conv2(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln2(hidden_states)) + return self.proj(hidden_states).squeeze(dim=2) + + +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4THifiGan with SeamlessM4T->SeamlessM4Tv2 +class SeamlessM4Tv2HifiGan(nn.Module): + def __init__(self, config: SeamlessM4Tv2Config): + super().__init__() + model_in_dim = config.unit_embed_dim + config.lang_embed_dim + config.spkr_embed_dim + self.leaky_relu_slope = config.leaky_relu_slope + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = 
nn.Conv1d(
+            model_in_dim,
+            config.upsample_initial_channel,
+            kernel_size=7,
+            stride=1,
+            padding=3,
+        )
+
+        self.upsampler = nn.ModuleList()
+        for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
+            self.upsampler.append(
+                nn.ConvTranspose1d(
+                    config.upsample_initial_channel // (2**i),
+                    config.upsample_initial_channel // (2 ** (i + 1)),
+                    kernel_size=kernel_size,
+                    stride=upsample_rate,
+                    padding=(kernel_size - upsample_rate) // 2,
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.upsampler)):
+            channels = config.upsample_initial_channel // (2 ** (i + 1))
+            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
+                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
+
+        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
+
+    def forward(self, input_embeds: torch.FloatTensor) -> torch.FloatTensor:
+        r"""
+        Converts a sequence of input embeddings into a speech waveform. Passing a batch of input embeddings returns a
+        batch of speech waveforms. Passing a single, un-batched sequence of input embeddings returns a single,
+        un-batched speech waveform.
+
+        Args:
+            input_embeds (`torch.FloatTensor`):
+                Tensor containing the input embeddings. Can be batched and of shape `(batch_size, model_in_dim,
+                sequence_length)`, or un-batched and of shape `(model_in_dim, sequence_length)`. Note that
+                `model_in_dim` is the sum of `config.unit_embed_dim`, `config.lang_embed_dim` and
+                `config.spkr_embed_dim`.
+
+        Returns:
+            `torch.FloatTensor`: Tensor containing the speech waveform. If the input is batched, will be of
+            shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
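+
+        Example (an illustrative sketch; `config` is assumed to be a `SeamlessM4Tv2Config`):
+
+        ```python
+        >>> import torch
+        >>> vocoder = SeamlessM4Tv2HifiGan(config)
+        >>> model_in_dim = config.unit_embed_dim + config.lang_embed_dim + config.spkr_embed_dim
+        >>> input_embeds = torch.randn(2, model_in_dim, 50)  # channels-first, as concatenated by SeamlessM4Tv2CodeHifiGan
+        >>> waveform = vocoder(input_embeds)  # shape (2, num_frames)
+        ```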
+ """ + + hidden_states = self.conv_pre(input_embeds) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +@add_start_docstrings( + """Code HiFi-GAN vocoder as described in this [repository](https://github.com/facebookresearch/speech-resynthesis).""", + HIFIGAN_START_DOCSTRING, +) +class SeamlessM4Tv2CodeHifiGan(PreTrainedModel): + config_class = SeamlessM4Tv2Config + main_input_name = "input_embeds" + _no_split_modules = [] + + def __init__(self, config): + super().__init__(config) + + self.pad_token_id = config.t2u_pad_token_id + embed_dim = config.unit_embed_dim + kernel_size = config.variance_predictor_kernel_size + var_pred_dropout = config.var_pred_dropout + self.dur_predictor = SeamlessM4Tv2VariancePredictor(embed_dim, embed_dim, kernel_size, var_pred_dropout) + + self.unit_embedding = nn.Embedding(config.unit_hifi_gan_vocab_size, config.unit_embed_dim) + self.speaker_embedding = nn.Embedding(config.vocoder_num_spkrs, config.spkr_embed_dim) + self.language_embedding = nn.Embedding(config.vocoder_num_langs, config.lang_embed_dim) + + self.hifi_gan = SeamlessM4Tv2HifiGan(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._get_dur_output_lengths + def _get_dur_output_lengths(self, input_ids, dur_out): + """ + Computes the output length after the duration layer. 
+ """ + unit_lengths = (input_ids != self.pad_token_id).sum(1) + + # take care of edge cases where no padding or too many padding + unit_lengths = torch.clamp(unit_lengths, 0, dur_out.shape[1] - 1) + + cumulative_dur_out = torch.cumsum(dur_out, dim=1) + unit_lengths = cumulative_dur_out.gather(dim=1, index=unit_lengths.unsqueeze(1)).squeeze() + + return unit_lengths + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._get_output_hifigan_lengths + def _get_output_hifigan_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the hifigan convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return ( + torch.div(input_length + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor") + 1 + ) + + def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1 + + # conv_pre + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + # upsampler + for i, (upsample_rate, kernel_size) in enumerate( + zip(self.config.upsample_rates, self.config.upsample_kernel_sizes) + ): + input_lengths = _transpose_conv_out_length( + input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2 + ) + + # resblock + for i in range(len(self.config.upsample_rates)): + for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes): + for dil in dilation: + input_lengths = _conv_out_length( + input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil + ) + + for dil in dilation: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1) + + # conv_post + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + return input_lengths + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.forward with SeamlessM4T->SeamlessM4Tv2, spkr_id->speaker_id + def forward( + self, input_ids: torch.LongTensor, speaker_id: torch.Tensor, lang_id: torch.Tensor + ) -> Tuple[torch.Tensor]: + """ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4Tv2TextToUnitForConditionalGeneration`]. [What are input + IDs?](../glossary#input-ids) + speaker_id (`int`, *optional*): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + tgt_lang (`str`, *optional*): + The language id to use as target language for translation. + """ + hidden_states = self.unit_embedding(input_ids).transpose(1, 2) + spkr = self.speaker_embedding(speaker_id).transpose(1, 2) + lang = self.language_embedding(lang_id).transpose(1, 2) + + log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2)) + dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1) + # B x C x T + if hidden_states.size(0) == 1: + hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2) + else: + # if batched sample, need to interleave per sample, and pad -> loss of parallelism + if hidden_states.shape[0] > 1 and self.training: + logger.warning( + """`self.training=True` and you use batching. 
You lose parallelism during the hifigan + forward pass because the samples are interleaved.""" + ) + hidden_states = [ + torch.repeat_interleave(hidden_state, duration, dim=-1).transpose(0, 1) + for (hidden_state, duration) in zip(hidden_states, dur_out) + ] + + hidden_states = nn.utils.rnn.pad_sequence(hidden_states, batch_first=True).transpose(1, 2) + + spkr = spkr.repeat(1, 1, hidden_states.shape[-1]) + lang = lang.repeat(1, 1, hidden_states.shape[-1]) + hidden_states = torch.cat([lang, hidden_states, spkr], dim=1) + + hidden_states = self.hifi_gan(hidden_states) + + unit_lengths = self._get_dur_output_lengths(input_ids, dur_out) + lengths = self._get_output_hifigan_lengths(unit_lengths) + + return hidden_states, lengths + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._init_weights + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.apply_weight_norm() + weight_norm(self.hifi_gan.conv_post) + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.remove_weight_norm + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.hifi_gan.conv_post) + + +############ WHOLE MODEL related code ################ + + +@add_start_docstrings( + "The text-to-text SeamlessM4Tv2 Model transformer which can be used for T2TT.", + SEAMLESS_M4T_V2_START_DOCSTRING, +) +# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToText with SeamlessM4T->SeamlessM4Tv2,SeamlessM4Tv2Tokenizer->SeamlessM4TTokenizer, SeamlessM4Tv2Processor->SeamlessM4TProcessor +class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4Tv2Config): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared) + self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): 
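+        # lm_head projects decoder hidden states onto the text vocabulary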
+ return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = 
labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_ids=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. 
It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # prepare text_decoder_input_ids + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {", ".join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. 
Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + + return super().generate( + input_ids, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The speech-to-text SeamlessM4Tv2 Model transformer which can be used for S2TT.", + SEAMLESS_M4T_V2_START_DOCSTRING, +) +class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["text_decoder", "t2u_model", "vocoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.__init__ with SeamlessM4T->SeamlessM4Tv2 + def __init__(self, config: SeamlessM4Tv2Config): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config) + self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.get_encoder + def get_encoder(self): + return self.speech_encoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.get_decoder + def get_decoder(self): + return self.text_decoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.get_input_embeddings + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.set_input_embeddings + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText._tie_weights + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING) + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.forward + def forward( + self, + input_features: 
Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + 
past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.generate
+    def generate(
+        self,
+        input_features=None,
+        tgt_lang=None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=False,
+        **kwargs,
+    ):
+        """
+        Generates sequences of token ids.
+
+
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+
+        Parameters:
+            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`):
+                Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criteria is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constraints the beam search to allowed tokens only at each step. If not
+                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
+                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+                Retrieval](https://arxiv.org/abs/2010.00904).
+ synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible + [`~utils.ModelOutput`] types are: + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + input_features = input_features if input_features is not None else kwargs.pop("inputs") + if tgt_lang is not None: + inputs = kwargs.get("input_embeds") if input_features is None else input_features + inputs = ( + inputs + if inputs is not None + else kwargs.get("encoder_outputs", {"last_hidden_state": None})["last_hidden_state"] + ) + batch_size = len(inputs) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {", ".join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. 
Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + return super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + @staticmethod + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The text-to-speech SeamlessM4Tv2 Model transformer which can be used for T2ST.", + SEAMLESS_M4T_V2_START_DOCSTRING, +) +class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = ["speech_encoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.__init__ with SeamlessM4T->SeamlessM4Tv2 + def __init__(self, config: SeamlessM4Tv2Config): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared) + self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4Tv2TextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4Tv2CodeHifiGan(config) + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.get_encoder + def get_encoder(self): + return self.text_encoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.get_decoder + def get_decoder(self): + return self.text_decoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.get_input_embeddings + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.set_input_embeddings + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech._tie_weights + def _tie_weights(self): + if self.config.tie_word_embeddings: + 
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
+
+    @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.forward with SeamlessM4T->SeamlessM4Tv2
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn
+            logger.warning(
+                "This is the same forward method as `SeamlessM4Tv2ForTextToText`. "
+                "It doesn't use the text-to-unit model `SeamlessM4Tv2TextToUnitForConditionalGeneration`. "
+                "If you want to generate speech, use the `.generate` method."
+ ) + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + speaker_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4Tv2GenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. 
Set to `True` if you also want
+                to get translated text alongside the audio.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            speaker_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the
+                    text model and speech model respectively. They have priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+        Returns:
+            `Union[SeamlessM4Tv2GenerationOutput, Tuple[Tensor]]`:
+            - If `return_intermediate_token_ids`, returns [`SeamlessM4Tv2GenerationOutput`].
+            - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
+              sequence_length)` and `waveform_lengths` which gives the length of each sample.
+        """
+        batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))
+
+        if tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+        else:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id. Make sure to load the right generation config."""
+                    )
+                elif tgt_lang not in lang_code_to_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model.
+                        Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4Tv2 supports
+                        more languages for text translation than for speech synthesis."""
+                    )
+
+        kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs)
+        kwargs_text["output_hidden_states"] = True
+        kwargs_text["return_dict_in_generate"] = True
+        kwargs_text["output_scores"] = True
+
+        text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
+
+        # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
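+        # For example (illustrative values only, not real vocabulary ids): with `tgt_lang="fra"`
+        # mapped to id 256057 and `batch_size=2`, the decoder prompt built below would be
+        # `tensor([[256057], [256057]])`, replacing any user-supplied `decoder_input_ids`.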
+ text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_ids, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + if attention_mask is not None: + # repeat attention mask alongside batch dimension + attention_mask = torch.repeat_interleave(attention_mask, num_return_sequences, dim=0) + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # repeat attention mask alongside batch dimension + encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, num_return_sequences, dim=0) + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences[:, :-1], # Manually trim the final EOS token + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences[:, :-1] != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # REMOVE EOS and lang_id + t2u_input_ids = sequences[:, 2:-1] + # replace every other EOS + t2u_input_ids = torch.masked_fill( + t2u_input_ids, t2u_input_ids == self.generation_config.eos_token_id, pad_token_id + ) + + # compute t2u_char_input_ids + t2u_subwords = self._indices_to_subwords(t2u_input_ids) + t2u_char_count_per_id = self._count_character_length_in_subword( + t2u_input_ids, t2u_subwords, pad_token_id=pad_token_id + ) + + # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode. + pad_zero = t2u_char_count_per_id.new_zeros((t2u_char_count_per_id.shape[0], 1)) + t2u_char_count_per_id = torch.cat([pad_zero, t2u_char_count_per_id, pad_zero], dim=1) + t2u_char_input_ids = self._get_char_input_ids( + t2u_input_ids, t2u_subwords, t2u_char_count_per_id, pad_token_id=pad_token_id + ) + + # second pass + t2u_output = self.t2u_model( + inputs_embeds=t2u_input_embeds, + char_input_ids=t2u_char_input_ids, + char_count_per_id=t2u_char_count_per_id, + **kwargs_speech, + ) + + t2u_logits = t2u_output[0] + padding_mask = t2u_output[1].bool() + + # The text-to-unit model is non auto-regressive. 
We keep the ability to use sampling with temperature
+        temperature = kwargs_speech.get("temperature", None)
+        if (temperature is None or temperature == 1.0) or not kwargs_speech.get("do_sample", False):
+            unit_ids = t2u_logits.argmax(dim=-1)
+        else:
+            t2u_logits = t2u_logits / temperature
+            # apply softmax
+            probs = nn.functional.softmax(t2u_logits, dim=-1)
+            # reshape to 2D: (batch_size, seq_len, t2u_vocab_size) -> (batch_size*seq_len, t2u_vocab_size)
+            probs = probs.reshape((-1, probs.shape[2]))
+            # multinomial then reshape: (batch_size*seq_len) -> (batch_size, seq_len)
+            unit_ids = torch.multinomial(probs, num_samples=1).view(t2u_logits.shape[0], -1)
+
+        output_unit_ids = unit_ids.detach().clone()
+
+        replace_mask = (unit_ids == self.config.t2u_eos_token_id) | (~padding_mask)
+        # replace eos with pad
+        unit_ids = unit_ids.masked_fill(replace_mask, self.config.t2u_pad_token_id)
+
+        # offset of control symbols
+        unit_ids = torch.where(
+            unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset
+        )
+
+        vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
+        vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
+
+        speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
+
+        waveform, waveform_lengths = self.vocoder(
+            input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
+        )
+
+        if return_intermediate_token_ids:
+            return SeamlessM4Tv2GenerationOutput(
+                waveform=waveform,
+                waveform_lengths=waveform_lengths,
+                sequences=sequences,
+                unit_sequences=output_unit_ids,
+            )
+
+        return waveform, waveform_lengths
+
+    @staticmethod
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech._reorder_cache
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    "The speech-to-speech SeamlessM4Tv2 Model transformer which can be used for S2ST.",
+    SEAMLESS_M4T_V2_START_DOCSTRING,
+)
+class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
+    _keys_to_ignore_on_load_missing = ["text_encoder"]
+    main_input_name = "input_features"
+
+    _tied_weights_keys = [
+        "lm_head.weight",
+        "text_decoder.embed_tokens.weight",
+    ]
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.__init__ with SeamlessM4T->SeamlessM4Tv2
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+        self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config)
+        self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        self.t2u_model = SeamlessM4Tv2TextToUnitForConditionalGeneration(config)
+        self.vocoder = SeamlessM4Tv2CodeHifiGan(config)
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.get_encoder
+    def get_encoder(self):
+        return self.speech_encoder
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.get_decoder
+    def get_decoder(self):
+        return self.text_decoder
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.get_output_embeddings
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.set_output_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.get_input_embeddings
+    def get_input_embeddings(self):
+        return self.text_decoder.embed_tokens
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.set_input_embeddings
+    def set_input_embeddings(self, value):
+        self.text_decoder.embed_tokens = value
+
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.lm_head, self.shared)
+
+    @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.forward with SeamlessM4T->SeamlessM4Tv2
+    def forward(
+        self,
+        input_features: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn
+            logger.warning(
+                "This is the same forward method as `SeamlessM4Tv2ForSpeechToText`. It doesn't use `self.t2u_model`. "
+                "If you want to generate speech, use the `generate` method."
+ ) + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + speaker_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4Tv2GenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_features, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. 
This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            speaker_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the
+                    text model and speech model respectively. They have priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+        Returns:
+            `Union[SeamlessM4Tv2GenerationOutput, Tuple[Tensor]]`:
+            - If `return_intermediate_token_ids`, returns [`SeamlessM4Tv2GenerationOutput`].
+            - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size,
+              sequence_length)` and `waveform_lengths` which gives the length of each sample.
+        """
+        batch_size = len(input_features) if input_features is not None else len(kwargs.get("inputs_embeds"))
+
+        if tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+        else:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id. Make sure to load the right generation config."""
+                    )
+                elif tgt_lang not in lang_code_to_id:
+                    raise ValueError(
+                        f"""`tgt_lang={tgt_lang}` is not supported by this model.
+                        Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4Tv2 supports
+                        more languages for text translation than for speech synthesis."""
+                    )
+
+        kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs)
+        kwargs_text["output_hidden_states"] = True
+        kwargs_text["return_dict_in_generate"] = True
+        kwargs_text["output_scores"] = True
+
+        text_decoder_input_ids = kwargs_text.get("decoder_input_ids")
+        # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
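+        # (descriptive overview of the S2ST pipeline below: a first text `generate` pass produces
+        # translated token ids; the text decoder's hidden states then condition the
+        # non-autoregressive `t2u_model`; the predicted units are finally vocoded by `self.vocoder`)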
+ text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_features, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get last_hidden_state from encoder + encoder_hidden_states = self.speech_encoder(input_features=input_features, attention_mask=attention_mask)[0] + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + + # repeat attention mask alongside batch dimension + attention_mask = torch.repeat_interleave(attention_mask, num_return_sequences, dim=0) + + # repeat attention mask alongside batch dimension + encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, num_return_sequences, dim=0) + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences[:, :-1], # Manually trim the final EOS token + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences[:, :-1] != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # REMOVE EOS and lang_id + t2u_input_ids = sequences[:, 2:-1] + # replace every other EOS + t2u_input_ids = torch.masked_fill( + t2u_input_ids, t2u_input_ids == self.generation_config.eos_token_id, pad_token_id + ) + + # compute t2u_char_input_ids + t2u_subwords = self._indices_to_subwords(t2u_input_ids) + t2u_char_count_per_id = self._count_character_length_in_subword( + t2u_input_ids, t2u_subwords, pad_token_id=pad_token_id + ) + + # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode. + pad_zero = t2u_char_count_per_id.new_zeros((t2u_char_count_per_id.shape[0], 1)) + t2u_char_count_per_id = torch.cat([pad_zero, t2u_char_count_per_id, pad_zero], dim=1) + t2u_char_input_ids = self._get_char_input_ids( + t2u_input_ids, t2u_subwords, t2u_char_count_per_id, pad_token_id=pad_token_id + ) + + # second pass + t2u_output = self.t2u_model( + inputs_embeds=t2u_input_embeds, + char_input_ids=t2u_char_input_ids, + char_count_per_id=t2u_char_count_per_id, + **kwargs_speech, + ) + + t2u_logits = t2u_output[0] + padding_mask = t2u_output[1].bool() + + # The text-to-unit model is non auto-regressive. 
We keep the ability to use sampling with temperature
+        temperature = kwargs_speech.get("temperature", None)
+        if (temperature is None or temperature == 1.0) or not kwargs_speech.get("do_sample", False):
+            unit_ids = t2u_logits.argmax(dim=-1)
+        else:
+            t2u_logits = t2u_logits / temperature
+            # apply softmax
+            probs = nn.functional.softmax(t2u_logits, dim=-1)
+            # reshape to 2D: (batch_size, seq_len, t2u_vocab_size) -> (batch_size*seq_len, t2u_vocab_size)
+            probs = probs.reshape((-1, probs.shape[2]))
+            # multinomial then reshape: (batch_size*seq_len) -> (batch_size, seq_len)
+            unit_ids = torch.multinomial(probs, num_samples=1).view(t2u_logits.shape[0], -1)
+
+        output_unit_ids = unit_ids.detach().clone()
+
+        replace_mask = (unit_ids == self.config.t2u_eos_token_id) | (~padding_mask)
+        # replace eos with pad
+        unit_ids = unit_ids.masked_fill(replace_mask, self.config.t2u_pad_token_id)
+
+        # offset of control symbols
+        unit_ids = torch.where(
+            unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset
+        )
+
+        vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
+        vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
+
+        speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
+
+        waveform, waveform_lengths = self.vocoder(
+            input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
+        )
+
+        if return_intermediate_token_ids:
+            return SeamlessM4Tv2GenerationOutput(
+                waveform=waveform,
+                waveform_lengths=waveform_lengths,
+                sequences=sequences,
+                unit_sequences=output_unit_ids,
+            )
+
+        return waveform, waveform_lengths
+
+    @staticmethod
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech._reorder_cache
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    "The original SeamlessM4Tv2 Model transformer which can be used for every task available (S2ST, S2TT, T2TT, T2ST).",
+    SEAMLESS_M4T_V2_START_DOCSTRING,
+    """
+        current_modality (`str`, *optional*, defaults to `"text"`):
+            Default modality. Used only to initialize the model. It can be set to `"text"` or `"speech"`.
+            This will be updated automatically according to the modality passed to the forward and generate passes (`input_ids` for text and `input_features` for audio).
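+            For instance, passing `input_features` to `forward` or `generate` switches the model to the `"speech"`
+            modality (see `set_modality`), while passing `input_ids` switches it back to `"text"`.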
+ """, +) +class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.__init__ with SeamlessM4T->SeamlessM4Tv2 + def __init__(self, config, current_modality="text"): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared) + self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config) + self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.current_modality = current_modality + if current_modality == "speech": + self.main_input_name = "input_features" + + # these models already call post_init in their initialization + self.t2u_model = SeamlessM4Tv2TextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4Tv2CodeHifiGan(config) + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.set_modality + def set_modality(self, modality="text"): + if modality == "text": + self.main_input_name = "input_ids" + self.current_modality = "text" + elif modality == "speech": + self.main_input_name = "input_features" + self.current_modality = "speech" + else: + raise ValueError(f"`modality={modality}` is not a valid modality. It must be `text` or `speech`.") + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.get_encoder + def get_encoder(self): + if self.current_modality == "text": + return self.text_encoder + else: + return self.speech_encoder + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.get_input_embeddings + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.set_input_embeddings + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel._tie_weights + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.lm_head, self.shared) + + @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING) + # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.forward with SeamlessM4T->SeamlessM4Tv2 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: 
Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        if input_ids is None and input_features is None and inputs_embeds is None and encoder_outputs is None:
+            raise ValueError(
+                "`input_ids`, `input_features`, `inputs_embeds` and `encoder_outputs` are all empty. Make sure at least one of them is not."
+            )
+        elif input_features is not None:
+            if input_ids is not None:
+                logger.warning(
+                    "`input_ids` is not `None` but `input_features` has been given. "
+                    "`input_features` will be used in priority through the `speech_encoder`. "
+                    "Make sure that `input_features` and `input_ids` are mutually exclusive."
+                )
+
+            if inputs_embeds is not None:
+                logger.warning(
+                    "`inputs_embeds` is not `None` but `input_features` has been given. "
+                    "`input_features` will be used in priority through `speech_encoder`. "
+                    "`inputs_embeds` will be ignored."
+                )
+
+            # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn
+            logger.warning(
+                "This calls the same method `forward` as `SeamlessM4Tv2ForTextToText` and `SeamlessM4Tv2ForSpeechToText` "
+                "depending on the input modality. If you want to generate speech, use the `generate` method."
+            )
+
+            self.set_modality("speech")
+
+            encoder_outputs = self.speech_encoder(
+                input_features=input_features,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        elif input_ids is not None or inputs_embeds is not None:
+            # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn
+            logger.warning(
+                "This calls the same method `forward` as `SeamlessM4Tv2ForTextToText` and `SeamlessM4Tv2ForSpeechToText` "
+                "depending on the input modality. If you want to generate speech, use the `generate` method."
+ ) + self.set_modality("text") + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + # input modality = speech so new attention mask + if self.current_modality == "speech" and attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + speaker_id: Optional[int] = 0, + generate_speech: Optional[bool] = True, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4Tv2GenerationOutput]: + """ + Generates translated token ids and/or translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids=input_ids, num_beams=4, speech_do_sample=True)` will successively + perform beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. 
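+
+        A minimal usage sketch (the checkpoint name and language codes here are illustrative assumptions):
+
+        ```python
+        >>> from transformers import AutoProcessor, SeamlessM4Tv2Model
+
+        >>> processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
+        >>> model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large")
+
+        >>> text_inputs = processor(text="Hello, my dog is cute", src_lang="eng", return_tensors="pt")
+        >>> # `text_num_beams` steers only the text sub-model, `speech_do_sample` only the speech sub-model
+        >>> waveform, waveform_lengths = model.generate(**text_inputs, tgt_lang="fra", text_num_beams=4, speech_do_sample=True)
+        ```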
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See
+                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*):
+                Input audio features. This should be returned by the [`SeamlessM4TFeatureExtractor`] class or the
+                [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details.
+            return_intermediate_token_ids (`bool`, *optional*):
+                If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want
+                to get translated text alongside the audio. Note that if `generate_speech=False`, this parameter will be
+                ignored.
+            tgt_lang (`str`, *optional*):
+                The language to use as target language for translation.
+            speaker_id (`int`, *optional*, defaults to 0):
+                The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`.
+            generate_speech (`bool`, *optional*, defaults to `True`):
+                If `False`, will only return the text tokens and won't generate speech.
+            kwargs (*optional*):
+                Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword
+                arguments are of two types:
+
+                    - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model,
+                    except for `decoder_input_ids` which will only be passed through the text components.
+                    - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the
+                    text model and speech model respectively. They have priority over the keywords without a prefix.
+
+                    This means you can, for example, specify a generation strategy for one generation but not for the
+                    other.
+
+        Returns:
+            `Union[SeamlessM4Tv2GenerationOutput, Tuple[Tensor], ModelOutput]`:
+            - If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4Tv2GenerationOutput`].
+            - If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of
+              shape `(batch_size, sequence_length)` and `waveform_lengths` which gives the length of each sample.
+            - If `generate_speech=False`, it will return `ModelOutput`.
+        """
+        if input_ids is None and input_features is None and kwargs.get("inputs_embeds", None) is None:
+            raise ValueError(
+                "`input_ids`, `input_features` and `inputs_embeds` are all empty. Make sure at least one of them is not."
+            )
+
+        if generate_speech and tgt_lang is None:
+            raise ValueError("You must specify a `tgt_lang` to generate translated speech.")
+
+        if tgt_lang is not None:
+            # also accept __xxx__
+            tgt_lang = tgt_lang.replace("__", "")
+            if generate_speech:
+                keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
+            else:
+                keys_to_check = ["text_decoder_lang_to_code_id"]
+            for key in keys_to_check:
+                lang_code_to_id = getattr(self.generation_config, key, None)
+                if lang_code_to_id is None:
+                    raise ValueError(
+                        f"""This model generation config doesn't have a `{key}` key which maps the target language
+                        to the right token id.
Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {",".join(lang_code_to_id.keys())}. Note that SeamlessM4Tv2 supports + more languages for text translation than for speech synthesis.""" + ) + + batch_size = ( + len(input_features) + if input_features is not None + else (len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))) + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + if input_features is not None: + self.set_modality("speech") + if input_ids is not None: + logger.warning( + "`input_features` and `input_ids` are both non empty. `input_features` will be used in priority " + "through the speech encoder. Make sure `input_features=None` if you want to use the text encoder." + ) + text_generation_output = super().generate(input_features=input_features, **kwargs_text) + else: + self.set_modality("text") + text_generation_output = super().generate(input_ids=input_ids, input_features=None, **kwargs_text) + sequences = text_generation_output.sequences + + if not generate_speech: + return text_generation_output + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get encoder last hidden states + if self.current_modality == "speech": + # get last_hidden_state from encoder - must do a pass through the speech encoder + encoder_hidden_states = self.speech_encoder( + input_features=input_features, attention_mask=attention_mask + ).last_hidden_state + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + else: + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + if attention_mask is not None: + # repeat attention mask alongside batch dimension + attention_mask = torch.repeat_interleave(attention_mask, num_return_sequences, dim=0) + + # repeat attention mask alongside batch dimension + encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, num_return_sequences, dim=0) + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences[:, :-1], # Manually trim the final EOS token + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences[:, :-1] != 
pad_token_id).int().sum(1)
+        t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens)
+        kwargs_speech["attention_mask"] = t2u_model_attention_mask
+
+        # REMOVE EOS and lang_id
+        t2u_input_ids = sequences[:, 2:-1]
+        # replace every other EOS
+        t2u_input_ids = torch.masked_fill(
+            t2u_input_ids, t2u_input_ids == self.generation_config.eos_token_id, pad_token_id
+        )
+
+        # compute t2u_char_input_ids
+        t2u_subwords = self._indices_to_subwords(t2u_input_ids)
+        t2u_char_count_per_id = self._count_character_length_in_subword(
+            t2u_input_ids, t2u_subwords, pad_token_id=pad_token_id
+        )
+
+        # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode.
+        pad_zero = t2u_char_count_per_id.new_zeros((t2u_char_count_per_id.shape[0], 1))
+        t2u_char_count_per_id = torch.cat([pad_zero, t2u_char_count_per_id, pad_zero], dim=1)
+        t2u_char_input_ids = self._get_char_input_ids(
+            t2u_input_ids, t2u_subwords, t2u_char_count_per_id, pad_token_id=pad_token_id
+        )
+
+        # second pass
+        t2u_output = self.t2u_model(
+            inputs_embeds=t2u_input_embeds,
+            char_input_ids=t2u_char_input_ids,
+            char_count_per_id=t2u_char_count_per_id,
+            **kwargs_speech,
+        )
+
+        t2u_logits = t2u_output[0]
+        padding_mask = t2u_output[1].bool()
+
+        # The text-to-unit model is non auto-regressive. We keep the ability to use sampling with temperature
+        temperature = kwargs_speech.get("temperature", None)
+        if (temperature is None or temperature == 1.0) or not kwargs_speech.get("do_sample", False):
+            unit_ids = t2u_logits.argmax(dim=-1)
+        else:
+            t2u_logits = t2u_logits / temperature
+            # apply softmax
+            probs = nn.functional.softmax(t2u_logits, dim=-1)
+            # reshape to 2D: (batch_size, seq_len, t2u_vocab_size) -> (batch_size*seq_len, t2u_vocab_size)
+            probs = probs.reshape((-1, probs.shape[2]))
+            # multinomial then reshape: (batch_size*seq_len) -> (batch_size, seq_len)
+            unit_ids = torch.multinomial(probs, num_samples=1).view(t2u_logits.shape[0], -1)
+
+        output_unit_ids = unit_ids.detach().clone()
+
+        replace_mask = (unit_ids == self.config.t2u_eos_token_id) | (~padding_mask)
+        # replace eos with pad
+        unit_ids = unit_ids.masked_fill(replace_mask, self.config.t2u_pad_token_id)
+
+        # offset of control symbols
+        unit_ids = torch.where(
+            unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset
+        )
+
+        vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang)
+        vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device)
+
+        speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device)
+
+        waveform, waveform_lengths = self.vocoder(
+            input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id
+        )
+
+        if return_intermediate_token_ids:
+            return SeamlessM4Tv2GenerationOutput(
+                waveform=waveform,
+                waveform_lengths=waveform_lengths,
+                sequences=sequences,
+                unit_sequences=output_unit_ids,
+            )
+
+        return waveform, waveform_lengths
+
+    @staticmethod
+    # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel._reorder_cache
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
+
+
+__all__ = [
+    "SeamlessM4Tv2ForTextToSpeech",
+    "SeamlessM4Tv2ForSpeechToSpeech",
+    "SeamlessM4Tv2ForTextToText",
"SeamlessM4Tv2ForSpeechToText", + "SeamlessM4Tv2Model", + "SeamlessM4Tv2PreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/segformer/__init__.py b/docs/transformers/src/transformers/models/segformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb469789773f13a19f62457cef3ece6311995fe --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_segformer import * + from .feature_extraction_segformer import * + from .image_processing_segformer import * + from .modeling_segformer import * + from .modeling_tf_segformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/segformer/configuration_segformer.py b/docs/transformers/src/transformers/models/segformer/configuration_segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..58683a86c78a5a01cb0e67dae53a9cbd37e78ecd --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/configuration_segformer.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SegFormer model configuration""" + +import warnings +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SegformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SegformerModel`]. It is used to instantiate an + SegFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SegFormer + [nvidia/segformer-b0-finetuned-ade-512-512](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_encoder_blocks (`int`, *optional*, defaults to 4): + The number of encoder blocks (i.e. stages in the Mix Transformer encoder). + depths (`List[int]`, *optional*, defaults to `[2, 2, 2, 2]`): + The number of layers in each encoder block. + sr_ratios (`List[int]`, *optional*, defaults to `[8, 4, 2, 1]`): + Sequence reduction ratios in each encoder block. + hidden_sizes (`List[int]`, *optional*, defaults to `[32, 64, 160, 256]`): + Dimension of each of the encoder blocks. + patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`): + Patch size before each encoder block. + strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`): + Stride before each encoder block. + num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + mlp_ratios (`List[int]`, *optional*, defaults to `[4, 4, 4, 4]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + classifier_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability before the classification head. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + drop_path_rate (`float`, *optional*, defaults to 0.1): + The dropout probability for stochastic depth, used in the blocks of the Transformer encoder. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + decoder_hidden_size (`int`, *optional*, defaults to 256): + The dimension of the all-MLP decode head. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. 
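+
+    The MiT `b0`-`b5` variants differ mainly in these size hyperparameters. As an illustrative sketch (the values
+    below mirror the `b2` branch of `convert_segformer_original_to_pytorch.py` further down; treat them as
+    indicative rather than an official preset), a larger MiT-b2-style configuration can be built by overriding the
+    relevant arguments:
+
+    ```python
+    >>> from transformers import SegformerConfig
+
+    >>> # MiT-b2-style sizes (see the conversion script for the full per-variant table)
+    >>> b2_config = SegformerConfig(
+    ...     hidden_sizes=[64, 128, 320, 512],
+    ...     decoder_hidden_size=768,
+    ...     depths=[3, 4, 6, 3],
+    ... )
+    ```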
+ + Example: + + ```python + >>> from transformers import SegformerModel, SegformerConfig + + >>> # Initializing a SegFormer nvidia/segformer-b0-finetuned-ade-512-512 style configuration + >>> configuration = SegformerConfig() + + >>> # Initializing a model from the nvidia/segformer-b0-finetuned-ade-512-512 style configuration + >>> model = SegformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "segformer" + + def __init__( + self, + num_channels=3, + num_encoder_blocks=4, + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + hidden_sizes=[32, 64, 160, 256], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + num_attention_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + classifier_dropout_prob=0.1, + initializer_range=0.02, + drop_path_rate=0.1, + layer_norm_eps=1e-6, + decoder_hidden_size=256, + semantic_loss_ignore_index=255, + **kwargs, + ): + super().__init__(**kwargs) + + if "reshape_last_stage" in kwargs and kwargs["reshape_last_stage"] is False: + warnings.warn( + "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be" + " removed, as the behaviour will default to that of reshape_last_stage = True.", + FutureWarning, + ) + + self.num_channels = num_channels + self.num_encoder_blocks = num_encoder_blocks + self.depths = depths + self.sr_ratios = sr_ratios + self.hidden_sizes = hidden_sizes + self.patch_sizes = patch_sizes + self.strides = strides + self.mlp_ratios = mlp_ratios + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.classifier_dropout_prob = classifier_dropout_prob + self.initializer_range = initializer_range + self.drop_path_rate = drop_path_rate + self.layer_norm_eps = layer_norm_eps + self.decoder_hidden_size = decoder_hidden_size + self.reshape_last_stage = kwargs.get("reshape_last_stage", True) + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +class SegformerOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + @property + def default_onnx_opset(self) -> int: + return 12 + + +__all__ = ["SegformerConfig", "SegformerOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/segformer/convert_segformer_original_to_pytorch.py b/docs/transformers/src/transformers/models/segformer/convert_segformer_original_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c84e006ad6480a130e9e642f0924f052da9e0419 --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/convert_segformer_original_to_pytorch.py @@ -0,0 +1,387 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert SegFormer checkpoints.""" + +import argparse +import json +from collections import OrderedDict +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + SegformerConfig, + SegformerForImageClassification, + SegformerForSemanticSegmentation, + SegformerImageProcessor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def rename_keys(state_dict, encoder_only=False): + new_state_dict = OrderedDict() + for key, value in state_dict.items(): + if encoder_only and not key.startswith("head"): + key = "segformer.encoder." + key + if key.startswith("backbone"): + key = key.replace("backbone", "segformer.encoder") + if "patch_embed" in key: + # replace for example patch_embed1 by patch_embeddings.0 + idx = key[key.find("patch_embed") + len("patch_embed")] + key = key.replace(f"patch_embed{idx}", f"patch_embeddings.{int(idx) - 1}") + if "norm" in key: + key = key.replace("norm", "layer_norm") + if "segformer.encoder.layer_norm" in key: + # replace for example layer_norm1 by layer_norm.0 + idx = key[key.find("segformer.encoder.layer_norm") + len("segformer.encoder.layer_norm")] + key = key.replace(f"layer_norm{idx}", f"layer_norm.{int(idx) - 1}") + if "layer_norm1" in key: + key = key.replace("layer_norm1", "layer_norm_1") + if "layer_norm2" in key: + key = key.replace("layer_norm2", "layer_norm_2") + if "block" in key: + # replace for example block1 by block.0 + idx = key[key.find("block") + len("block")] + key = key.replace(f"block{idx}", f"block.{int(idx) - 1}") + if "attn.q" in key: + key = key.replace("attn.q", "attention.self.query") + if "attn.proj" in key: + key = key.replace("attn.proj", "attention.output.dense") + if "attn" in key: + key = key.replace("attn", "attention.self") + if "fc1" in key: + key = key.replace("fc1", "dense1") + if "fc2" in key: + key = key.replace("fc2", "dense2") + if "linear_pred" in key: + key = key.replace("linear_pred", "classifier") + if "linear_fuse" in key: + key = key.replace("linear_fuse.conv", "linear_fuse") + key = key.replace("linear_fuse.bn", "batch_norm") + if "linear_c" in key: + # replace for example linear_c4 by linear_c.3 + idx = key[key.find("linear_c") + len("linear_c")] + key = key.replace(f"linear_c{idx}", f"linear_c.{int(idx) - 1}") + if key.startswith("head"): + key = key.replace("head", "classifier") + new_state_dict[key] = value + + return new_state_dict + + +def read_in_k_v(state_dict, config): + # for each of the encoder blocks: + for i in range(config.num_encoder_blocks): + for j in range(config.depths[i]): + # read in weights + bias of keys and values (which is a single matrix in the original implementation) + kv_weight = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.weight") + kv_bias = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.bias") + # next, add keys and values (in that order) to the state dict + state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.weight"] = kv_weight[ + : 
config.hidden_sizes[i], : + ] + state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.bias"] = kv_bias[: config.hidden_sizes[i]] + state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.weight"] = kv_weight[ + config.hidden_sizes[i] :, : + ] + state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.bias"] = kv_bias[ + config.hidden_sizes[i] : + ] + + +# We will verify our results on a COCO image +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + return image + + +@torch.no_grad() +def convert_segformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our SegFormer structure. + """ + + # load default SegFormer configuration + config = SegformerConfig() + encoder_only = False + + # set attributes based on model_name + repo_id = "huggingface/label-files" + if "segformer" in model_name: + size = model_name[len("segformer.") : len("segformer.") + 2] + if "ade" in model_name: + config.num_labels = 150 + filename = "ade20k-id2label.json" + expected_shape = (1, 150, 128, 128) + elif "city" in model_name: + config.num_labels = 19 + filename = "cityscapes-id2label.json" + expected_shape = (1, 19, 128, 128) + else: + raise ValueError(f"Model {model_name} not supported") + elif "mit" in model_name: + encoder_only = True + size = model_name[4:6] + config.num_labels = 1000 + filename = "imagenet-1k-id2label.json" + expected_shape = (1, 1000) + else: + raise ValueError(f"Model {model_name} not supported") + + # set config attributes + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + if size == "b0": + pass + elif size == "b1": + config.hidden_sizes = [64, 128, 320, 512] + config.decoder_hidden_size = 256 + elif size == "b2": + config.hidden_sizes = [64, 128, 320, 512] + config.decoder_hidden_size = 768 + config.depths = [3, 4, 6, 3] + elif size == "b3": + config.hidden_sizes = [64, 128, 320, 512] + config.decoder_hidden_size = 768 + config.depths = [3, 4, 18, 3] + elif size == "b4": + config.hidden_sizes = [64, 128, 320, 512] + config.decoder_hidden_size = 768 + config.depths = [3, 8, 27, 3] + elif size == "b5": + config.hidden_sizes = [64, 128, 320, 512] + config.decoder_hidden_size = 768 + config.depths = [3, 6, 40, 3] + else: + raise ValueError(f"Size {size} not supported") + + # load image processor (only resize + normalize) + image_processor = SegformerImageProcessor( + image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False + ) + + # prepare image + image = prepare_img() + pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + + logger.info(f"Converting model {model_name}...") + + # load original state dict + if encoder_only: + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True) + else: + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)["state_dict"] + + # rename keys + state_dict = rename_keys(state_dict, encoder_only=encoder_only) + if not encoder_only: + del state_dict["decode_head.conv_seg.weight"] + del state_dict["decode_head.conv_seg.bias"] + + # key and value matrices need special treatment + read_in_k_v(state_dict, config) + + # create HuggingFace model and load state dict + if encoder_only: + 
config.reshape_last_stage = False + model = SegformerForImageClassification(config) + else: + model = SegformerForSemanticSegmentation(config) + model.load_state_dict(state_dict) + model.eval() + + # forward pass + outputs = model(pixel_values) + logits = outputs.logits + + # set expected_slice based on model name + # ADE20k checkpoints + if model_name == "segformer.b0.512x512.ade.160k": + expected_slice = torch.tensor( + [ + [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]], + [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]], + [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]], + ] + ) + elif model_name == "segformer.b1.512x512.ade.160k": + expected_slice = torch.tensor( + [ + [[-7.5820, -8.7231, -8.3215], [-8.0600, -10.3529, -10.0304], [-7.5208, -9.4103, -9.6239]], + [[-12.6918, -13.8994, -13.7137], [-13.3196, -15.7523, -15.4789], [-12.9343, -14.8757, -14.9689]], + [[-11.1911, -11.9421, -11.3243], [-11.3342, -13.6839, -13.3581], [-10.3909, -12.1832, -12.4858]], + ] + ) + elif model_name == "segformer.b2.512x512.ade.160k": + expected_slice = torch.tensor( + [ + [[-11.8173, -14.3850, -16.3128], [-14.5648, -16.5804, -18.6568], [-14.7223, -15.7387, -18.4218]], + [[-15.7290, -17.9171, -19.4423], [-18.3105, -19.9448, -21.4661], [-17.9296, -18.6497, -20.7910]], + [[-15.0783, -17.0336, -18.2789], [-16.8771, -18.6870, -20.1612], [-16.2454, -17.1426, -19.5055]], + ] + ) + elif model_name == "segformer.b3.512x512.ade.160k": + expected_slice = torch.tensor( + [ + [[-9.0878, -10.2081, -10.1891], [-9.3144, -10.7941, -10.9843], [-9.2294, -10.3855, -10.5704]], + [[-12.2316, -13.9068, -13.6102], [-12.9161, -14.3702, -14.3235], [-12.5233, -13.7174, -13.7932]], + [[-14.6275, -15.2490, -14.9727], [-14.3400, -15.9687, -16.2827], [-14.1484, -15.4033, -15.8937]], + ] + ) + elif model_name == "segformer.b4.512x512.ade.160k": + expected_slice = torch.tensor( + [ + [[-12.3144, -13.2447, -14.0802], [-13.3614, -14.5816, -15.6117], [-13.3340, -14.4433, -16.2219]], + [[-19.2781, -20.4128, -20.7506], [-20.6153, -21.6566, -22.0998], [-19.9800, -21.0430, -22.1494]], + [[-18.8739, -19.7804, -21.1834], [-20.1233, -21.6765, -23.2944], [-20.0315, -21.2641, -23.6944]], + ] + ) + elif model_name == "segformer.b5.640x640.ade.160k": + expected_slice = torch.tensor( + [ + [[-9.5524, -12.0835, -11.7348], [-10.5229, -13.6446, -14.5662], [-9.5842, -12.8851, -13.9414]], + [[-15.3432, -17.5323, -17.0818], [-16.3330, -18.9255, -19.2101], [-15.1340, -17.7848, -18.3971]], + [[-12.6072, -14.9486, -14.6631], [-13.7629, -17.0907, -17.7745], [-12.7899, -16.1695, -17.1671]], + ] + ) + # Cityscapes checkpoints + elif model_name == "segformer.b0.1024x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-11.9295, -13.4057, -14.8106], [-13.3431, -14.8179, -15.3781], [-14.2836, -15.5942, -16.1588]], + [[-11.4906, -12.8067, -13.6564], [-13.1189, -14.0500, -14.1543], [-13.8748, -14.5136, -14.8789]], + [[0.5374, 0.1067, -0.4742], [0.1141, -0.2255, -0.7099], [-0.3000, -0.5924, -1.3105]], + ] + ) + elif model_name == "segformer.b0.512x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-7.8217, -9.8767, -10.1717], [-9.4438, -10.9058, -11.4047], [-9.7939, -12.3495, -12.1079]], + [[-7.1514, -9.5336, -10.0860], [-9.7776, -11.6822, -11.8439], [-10.1411, -12.7655, -12.8972]], + [[0.3021, 0.0805, -0.2310], [-0.0328, -0.1605, -0.2714], [-0.1408, -0.5477, -0.6976]], + ] + ) + elif model_name == 
"segformer.b0.640x1280.city.160k": + expected_slice = torch.tensor( + [ + [ + [-1.1372e01, -1.2787e01, -1.3477e01], + [-1.2536e01, -1.4194e01, -1.4409e01], + [-1.3217e01, -1.4888e01, -1.5327e01], + ], + [ + [-1.4791e01, -1.7122e01, -1.8277e01], + [-1.7163e01, -1.9192e01, -1.9533e01], + [-1.7897e01, -1.9991e01, -2.0315e01], + ], + [ + [7.6723e-01, 4.1921e-01, -7.7878e-02], + [4.7772e-01, 9.5557e-03, -2.8082e-01], + [3.6032e-01, -2.4826e-01, -5.1168e-01], + ], + ] + ) + elif model_name == "segformer.b0.768x768.city.160k": + expected_slice = torch.tensor( + [ + [[-9.4959, -11.3087, -11.7479], [-11.0025, -12.6540, -12.3319], [-11.4064, -13.0487, -12.9905]], + [[-9.8905, -11.3084, -12.0854], [-11.1726, -12.7698, -12.9583], [-11.5985, -13.3278, -14.1774]], + [[0.2213, 0.0192, -0.2466], [-0.1731, -0.4213, -0.4874], [-0.3126, -0.6541, -1.1389]], + ] + ) + elif model_name == "segformer.b1.1024x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]], + [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]], + [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]], + ] + ) + elif model_name == "segformer.b2.1024x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-16.0976, -16.4856, -17.3962], [-16.6234, -19.0342, -19.7685], [-16.0900, -18.0661, -19.1180]], + [[-18.4750, -18.8488, -19.5074], [-19.4030, -22.1570, -22.5977], [-19.1191, -20.8486, -22.3783]], + [[-4.5178, -5.5037, -6.5109], [-5.0884, -7.2174, -8.0334], [-4.4156, -5.8117, -7.2970]], + ] + ) + elif model_name == "segformer.b3.1024x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-14.2081, -14.4732, -14.1977], [-14.5867, -16.4423, -16.6356], [-13.4441, -14.9685, -16.8696]], + [[-14.4576, -14.7073, -15.0451], [-15.0816, -17.6237, -17.9873], [-14.4213, -16.0199, -18.5992]], + [[-4.7349, -4.9588, -5.0966], [-4.3210, -6.9325, -7.2591], [-3.4312, -4.7484, -7.1917]], + ] + ) + elif model_name == "segformer.b4.1024x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-11.7737, -11.9526, -11.3273], [-13.6692, -14.4574, -13.8878], [-13.8937, -14.6924, -15.9345]], + [[-14.6706, -14.5330, -14.1306], [-16.1502, -16.8180, -16.4269], [-16.8338, -17.8939, -20.1746]], + [[1.0491, 0.8289, 1.0310], [1.1044, 0.5219, 0.8055], [1.0899, 0.6926, 0.5590]], + ] + ) + elif model_name == "segformer.b5.1024x1024.city.160k": + expected_slice = torch.tensor( + [ + [[-12.5641, -13.4777, -13.0684], [-13.9587, -15.8983, -16.6557], [-13.3109, -15.7350, -16.3141]], + [[-14.7074, -15.4352, -14.5944], [-16.6353, -18.1663, -18.6120], [-15.1702, -18.0329, -18.1547]], + [[-1.7990, -2.0951, -1.7784], [-2.6397, -3.8245, -3.9686], [-1.5264, -2.8126, -2.9316]], + ] + ) + else: + predicted_class_idx = logits.argmax(-1).item() + print("Predicted class:", model.config.id2label[predicted_class_idx]) + + # verify logits + if not encoder_only: + assert logits.shape == expected_shape + assert torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-2) + + # finally, save model and image processor + logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name", + 
default="segformer.b0.512x512.ade.160k", + type=str, + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + args = parser.parse_args() + convert_segformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/segformer/feature_extraction_segformer.py b/docs/transformers/src/transformers/models/segformer/feature_extraction_segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b979acf71ae312976c62fbfac63fb5752e049d09 --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/feature_extraction_segformer.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for SegFormer.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_segformer import SegformerImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class SegformerFeatureExtractor(SegformerImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class SegformerFeatureExtractor is deprecated and will be removed in version 5 of Transformers." + " Please use SegformerImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["SegformerFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/segformer/image_processing_segformer.py b/docs/transformers/src/transformers/models/segformer/image_processing_segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..79cbe47482dca3b3d21a54d667a8a46ddca21722 --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/image_processing_segformer.py @@ -0,0 +1,484 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for Segformer.""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import requires + + +if is_vision_available(): + import PIL.Image + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class SegformerImageProcessor(BaseImageProcessor): + r""" + Constructs a Segformer image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. 
+ """ + + model_input_names = ["pixel_values"] + + @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0") + @filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS) + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_reduce_labels: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 512, "width": 512} + size = get_size_dict(size) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_reduce_labels = do_reduce_labels + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs + """ + image_processor_dict = image_processor_dict.copy() + if "reduce_labels" in image_processor_dict: + image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label + def reduce_label(self, label: ImageInput) -> np.ndarray: + label = to_numpy_array(label) + # Avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + return label + + def _preprocess( + self, + image: ImageInput, + do_reduce_labels: bool, + do_resize: bool, + do_rescale: bool, + do_normalize: bool, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + rescale_factor: Optional[float] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_reduce_labels: + image = self.reduce_label(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + image = self._preprocess( + image=image, + do_reduce_labels=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + input_data_format=input_data_format, + ) + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_reduce_labels: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single mask.""" + segmentation_map = to_numpy_array(segmentation_map) + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] 
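+            # the axis prepended by `[None, ...]` is a channel axis, so the layout is now channels-first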
+ input_data_format = ChannelDimension.FIRST + else: + added_channel_dim = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + # reduce zero label if needed + segmentation_map = self._preprocess( + image=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + resample=PILImageResampling.NEAREST, + size=size, + do_rescale=False, + do_normalize=False, + input_data_format=input_data_format, + ) + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + segmentation_map = segmentation_map.astype(np.int64) + return segmentation_map + + def __call__(self, images, segmentation_maps=None, **kwargs): + """ + Preprocesses a batch of images and optionally segmentation maps. + + Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be + passed in as positional arguments. + """ + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0") + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_reduce_labels: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after `resize` is applied. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). 
The background label will be replaced by 255. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + resample = resample if resample is not None else self.resample + size = size if size is not None else self.size + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + resample=resample, + size=size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + + data = {"pixel_values": images} + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_mask( + segmentation_map=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + size=size, + input_data_format=input_data_format, + ) + for segmentation_map in segmentation_maps + ] + data["labels"] = segmentation_maps + + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->Segformer + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`SegformerForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + +__all__ = ["SegformerImageProcessor"] diff --git a/docs/transformers/src/transformers/models/segformer/modeling_segformer.py b/docs/transformers/src/transformers/models/segformer/modeling_segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9da2cab7b341ece9337d3f6be32f472142bb6e --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/modeling_segformer.py @@ -0,0 +1,840 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SegFormer model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_segformer import SegformerConfig + + +logger = logging.get_logger(__name__) + + +# General docstring +_CONFIG_FOR_DOC = "SegformerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "nvidia/mit-b0" +_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class SegFormerImageClassifierOutput(ImageClassifierOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Segformer +class SegformerDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class SegformerOverlapPatchEmbeddings(nn.Module): + """Construct the overlapping patch embeddings.""" + + def __init__(self, patch_size, stride, num_channels, hidden_size): + super().__init__() + self.proj = nn.Conv2d( + num_channels, + hidden_size, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2, + ) + + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, pixel_values): + embeddings = self.proj(pixel_values) + _, _, height, width = embeddings.shape + # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels) + # this can be fed to a Transformer layer + embeddings = embeddings.flatten(2).transpose(1, 2) + embeddings = self.layer_norm(embeddings) + return embeddings, height, width + + +class SegformerEfficientSelfAttention(nn.Module): + """SegFormer's efficient self-attention mechanism. 
Employs the sequence reduction process introduced in the [PvT + paper](https://arxiv.org/abs/2102.12122).""" + + def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio): + super().__init__() + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " + f"heads ({self.num_attention_heads})" + ) + + self.attention_head_size = int(self.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(self.hidden_size, self.all_head_size) + self.key = nn.Linear(self.hidden_size, self.all_head_size) + self.value = nn.Linear(self.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.sr_ratio = sequence_reduction_ratio + if sequence_reduction_ratio > 1: + self.sr = nn.Conv2d( + hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio + ) + self.layer_norm = nn.LayerNorm(hidden_size) + + def transpose_for_scores(self, hidden_states): + new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + hidden_states = hidden_states.view(new_shape) + return hidden_states.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + height, + width, + output_attentions=False, + ): + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + if self.sr_ratio > 1: + batch_size, seq_len, num_channels = hidden_states.shape + # Reshape to (batch_size, num_channels, height, width) + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + # Apply sequence reduction + hidden_states = self.sr(hidden_states) + # Reshape back to (batch_size, seq_len, num_channels) + hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) + hidden_states = self.layer_norm(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class SegformerSelfOutput(nn.Module): + def __init__(self, config, hidden_size): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SegformerAttention(nn.Module): + def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio): + super().__init__() + self.self = SegformerEfficientSelfAttention( + config=config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequence_reduction_ratio=sequence_reduction_ratio, + ) + self.output = SegformerSelfOutput(config, hidden_size=hidden_size) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, height, width, output_attentions=False): + self_outputs = self.self(hidden_states, height, width, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SegformerDWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, hidden_states, height, width): + batch_size, seq_len, num_channels = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width) + hidden_states = self.dwconv(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + return hidden_states + + +class SegformerMixFFN(nn.Module): + def __init__(self, config, in_features, hidden_features=None, out_features=None): + super().__init__() + out_features = out_features or in_features + self.dense1 = nn.Linear(in_features, hidden_features) + self.dwconv = SegformerDWConv(hidden_features) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(hidden_features, out_features) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, height, width): + hidden_states = self.dense1(hidden_states) + hidden_states = self.dwconv(hidden_states, height, width) + 
hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SegformerLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio): + super().__init__() + self.layer_norm_1 = nn.LayerNorm(hidden_size) + self.attention = SegformerAttention( + config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequence_reduction_ratio=sequence_reduction_ratio, + ) + self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.layer_norm_2 = nn.LayerNorm(hidden_size) + mlp_hidden_size = int(hidden_size * mlp_ratio) + self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size) + + def forward(self, hidden_states, height, width, output_attentions=False): + self_attention_outputs = self.attention( + self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention + height, + width, + output_attentions=output_attentions, + ) + + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection (with stochastic depth) + attention_output = self.drop_path(attention_output) + hidden_states = attention_output + hidden_states + + mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width) + + # second residual connection (with stochastic depth) + mlp_output = self.drop_path(mlp_output) + layer_output = mlp_output + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class SegformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + # stochastic depth decay rule + drop_path_decays = [ + x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + + # patch embeddings + embeddings = [] + for i in range(config.num_encoder_blocks): + embeddings.append( + SegformerOverlapPatchEmbeddings( + patch_size=config.patch_sizes[i], + stride=config.strides[i], + num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], + hidden_size=config.hidden_sizes[i], + ) + ) + self.patch_embeddings = nn.ModuleList(embeddings) + + # Transformer blocks + blocks = [] + cur = 0 + for i in range(config.num_encoder_blocks): + # each block consists of layers + layers = [] + if i != 0: + cur += config.depths[i - 1] + for j in range(config.depths[i]): + layers.append( + SegformerLayer( + config, + hidden_size=config.hidden_sizes[i], + num_attention_heads=config.num_attention_heads[i], + drop_path=drop_path_decays[cur + j], + sequence_reduction_ratio=config.sr_ratios[i], + mlp_ratio=config.mlp_ratios[i], + ) + ) + blocks.append(nn.ModuleList(layers)) + + self.block = nn.ModuleList(blocks) + + # Layer norms + self.layer_norm = nn.ModuleList( + [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)] + ) + + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + batch_size = 
pixel_values.shape[0] + + hidden_states = pixel_values + for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)): + embedding_layer, block_layer, norm_layer = x + # first, obtain patch embeddings + hidden_states, height, width = embedding_layer(hidden_states) + # second, send embeddings through blocks + for i, blk in enumerate(block_layer): + layer_outputs = blk(hidden_states, height, width, output_attentions) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + # third, apply layer norm + hidden_states = norm_layer(hidden_states) + # fourth, optionally reshape back to (batch_size, num_channels, height, width) + if idx != len(self.patch_embeddings) - 1 or ( + idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage + ): + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SegformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SegformerConfig + base_model_prefix = "segformer" + main_input_name = "pixel_values" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SEGFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SegformerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SEGFORMER_INPUTS_DOCSTRING = r""" + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SegformerImageProcessor.__call__`] for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
+    SEGFORMER_START_DOCSTRING,
+)
+class SegformerModel(SegformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        # hierarchical Transformer encoder
+        self.encoder = SegformerEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See
+        base class `PreTrainedModel`.
+        """
+        # `SegformerEncoder` keeps its layers in `self.encoder.block`, a nested list with one list of
+        # `SegformerLayer`s per stage, so flatten it before indexing by the global layer number.
+        all_layers = [layer_module for stage in self.encoder.block for layer_module in stage]
+        for layer, heads in heads_to_prune.items():
+            all_layers[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_outputs = self.encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+
+        if not return_dict:
+            return (sequence_output,) + encoder_outputs[1:]
+
+        return BaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
+    states) e.g. for ImageNet.
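+
+    A usage sketch (illustrative; `nvidia/mit-b0` is the checkpoint referenced by this file's docstring examples):
+
+    ```python
+    >>> from transformers import AutoImageProcessor, SegformerForImageClassification
+    >>> from PIL import Image
+    >>> import requests
+
+    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/mit-b0")
+    >>> model = SegformerForImageClassification.from_pretrained("nvidia/mit-b0")
+
+    >>> inputs = image_processor(images=image, return_tensors="pt")
+    >>> logits = model(**inputs).logits  # (batch_size, config.num_labels)
+    ```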
+ """, + SEGFORMER_START_DOCSTRING, +) +class SegformerForImageClassification(SegformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.segformer = SegformerModel(config) + + # Classifier head + self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=SegFormerImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SegFormerImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.segformer( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # convert last hidden states to (batch_size, height*width, hidden_size) + batch_size = sequence_output.shape[0] + if self.config.reshape_last_stage: + # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + sequence_output = sequence_output.permute(0, 2, 3, 1) + sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1]) + + # global average pooling + sequence_output = sequence_output.mean(dim=1) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SegFormerImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class SegformerMLP(nn.Module): + """ + Linear Embedding. 
+ """ + + def __init__(self, config: SegformerConfig, input_dim): + super().__init__() + self.proj = nn.Linear(input_dim, config.decoder_hidden_size) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states.flatten(2).transpose(1, 2) + hidden_states = self.proj(hidden_states) + return hidden_states + + +class SegformerDecodeHead(SegformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size + mlps = [] + for i in range(config.num_encoder_blocks): + mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i]) + mlps.append(mlp) + self.linear_c = nn.ModuleList(mlps) + + # the following 3 layers implement the ConvModule of the original implementation + self.linear_fuse = nn.Conv2d( + in_channels=config.decoder_hidden_size * config.num_encoder_blocks, + out_channels=config.decoder_hidden_size, + kernel_size=1, + bias=False, + ) + self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size) + self.activation = nn.ReLU() + + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1) + + self.config = config + + def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor: + batch_size = encoder_hidden_states[-1].shape[0] + + all_hidden_states = () + for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): + if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3: + height = width = int(math.sqrt(encoder_hidden_state.shape[-1])) + encoder_hidden_state = ( + encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous() + ) + + # unify channel dimension + height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] + encoder_hidden_state = mlp(encoder_hidden_state) + encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) + encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width) + # upsample + encoder_hidden_state = nn.functional.interpolate( + encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False + ) + all_hidden_states += (encoder_hidden_state,) + + hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + + # logits are of shape (batch_size, num_labels, height/4, width/4) + logits = self.classifier(hidden_states) + + return logits + + +@add_start_docstrings( + """SegFormer Model transformer with an all-MLP decode head on top e.g. 
for ADE20k, CityScapes.""",
+    SEGFORMER_START_DOCSTRING,
+)
+class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.segformer = SegformerModel(config)
+        self.decode_head = SegformerDecodeHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SemanticSegmenterOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+        >>> model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
+        >>> list(logits.shape)
+        [1, 150, 128, 128]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if labels is not None and self.config.num_labels < 1:
+            raise ValueError(f"Number of labels should be >= 1, but got {self.config.num_labels}")
+
+        outputs = self.segformer(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # we need the intermediate hidden states
+            return_dict=return_dict,
+        )
+
+        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        logits = self.decode_head(encoder_hidden_states)
+
+        loss = None
+        if labels is not None:
+            # upsample logits to the images' original size
+            upsampled_logits = nn.functional.interpolate(
+                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+            )
+            if self.config.num_labels > 1:
+                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
+                loss = loss_fct(upsampled_logits, labels)
+            elif self.config.num_labels == 1:
+                valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
+                loss_fct = BCEWithLogitsLoss(reduction="none")
+                loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
+                loss = (loss * valid_mask).mean()
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (logits,) + outputs[1:]
+            else:
+                output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SemanticSegmenterOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions, + ) + + +__all__ = [ + "SegformerDecodeHead", + "SegformerForImageClassification", + "SegformerForSemanticSegmentation", + "SegformerLayer", + "SegformerModel", + "SegformerPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/segformer/modeling_tf_segformer.py b/docs/transformers/src/transformers/models/segformer/modeling_tf_segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bad72cdf2bbc1752279507febd14d86e94a3e20d --- /dev/null +++ b/docs/transformers/src/transformers/models/segformer/modeling_tf_segformer.py @@ -0,0 +1,1045 @@ +# coding=utf-8 +# Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorFlow SegFormer model.""" + +from __future__ import annotations + +import math +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import logging +from .configuration_segformer import SegformerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SegformerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "nvidia/mit-b0" +_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer +class TFSegformerDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
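+
+    During training, each example in the batch is zeroed out with probability `drop_path` and the survivors are
+    rescaled by `1 / (1 - drop_path)`, so the expected output is unchanged; outside training the layer is a no-op.
+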
+ References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path: float, **kwargs): + super().__init__(**kwargs) + self.drop_path = drop_path + + def call(self, x: tf.Tensor, training=None): + if training: + keep_prob = 1 - self.drop_path + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFSegformerOverlapPatchEmbeddings(keras.layers.Layer): + """Construct the overlapping patch embeddings.""" + + def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs): + super().__init__(**kwargs) + self.padding = keras.layers.ZeroPadding2D(padding=patch_size // 2) + self.proj = keras.layers.Conv2D( + filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj" + ) + + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm") + self.num_channels = num_channels + self.hidden_size = hidden_size + + def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]: + embeddings = self.proj(self.padding(pixel_values)) + height = shape_list(embeddings)[1] + width = shape_list(embeddings)[2] + hidden_dim = shape_list(embeddings)[3] + # (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels) + # this can be fed to a Transformer layer + embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim)) + embeddings = self.layer_norm(embeddings) + return embeddings, height, width + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build([None, None, None, self.num_channels]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.hidden_size]) + + +class TFSegformerEfficientSelfAttention(keras.layers.Layer): + """SegFormer's efficient self-attention mechanism. 
Employs the sequence reduction process introduced in the [PvT + paper](https://arxiv.org/abs/2102.12122).""" + + def __init__( + self, + config: SegformerConfig, + hidden_size: int, + num_attention_heads: int, + sequence_reduction_ratio: int, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention " + f"heads ({self.num_attention_heads})" + ) + + self.attention_head_size = self.hidden_size // self.num_attention_heads + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense(self.all_head_size, name="query") + self.key = keras.layers.Dense(self.all_head_size, name="key") + self.value = keras.layers.Dense(self.all_head_size, name="value") + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + + self.sr_ratio = sequence_reduction_ratio + if sequence_reduction_ratio > 1: + self.sr = keras.layers.Conv2D( + filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr" + ) + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm") + + def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] + # to [batch_size, seq_length, num_attention_heads, attention_head_size] + batch_size = shape_list(tensor)[0] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] + # to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + height: int, + width: int, + output_attentions: bool = False, + training: bool = False, + ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]: + batch_size = shape_list(hidden_states)[0] + num_channels = shape_list(hidden_states)[2] + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + if self.sr_ratio > 1: + # Reshape to (batch_size, height, width, num_channels) + hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels)) + # Apply sequence reduction + hidden_states = self.sr(hidden_states) + # Reshape back to (batch_size, seq_len, num_channels) + hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels)) + hidden_states = self.layer_norm(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + + scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, scale) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
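+        # A shape sketch of the reduction above (illustrative sizes): with
+        # height = width = 16 and sr_ratio = 4, the queries keep all 256
+        # positions while keys/values are computed on the strided 4x4 map, so
+        # `attention_scores` is (batch, heads, 256, 16) rather than
+        # (batch, heads, 256, 256).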
+ attention_probs = self.dropout(attention_probs, training=training) + + context_layer = tf.matmul(attention_probs, value_layer) + + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + # (batch_size, seq_len_q, all_head_size) + context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size)) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.hidden_size]) + if getattr(self, "sr", None) is not None: + with tf.name_scope(self.sr.name): + self.sr.build([None, None, None, self.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.hidden_size]) + + +class TFSegformerSelfOutput(keras.layers.Layer): + def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(hidden_size, name="dense") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.hidden_size = hidden_size + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.hidden_size]) + + +class TFSegformerAttention(keras.layers.Layer): + def __init__( + self, + config: SegformerConfig, + hidden_size: int, + num_attention_heads: int, + sequence_reduction_ratio: int, + **kwargs, + ): + super().__init__(**kwargs) + self.self = TFSegformerEfficientSelfAttention( + config=config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequence_reduction_ratio=sequence_reduction_ratio, + name="self", + ) + self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output") + + def call( + self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False + ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]: + self_outputs = self.self(hidden_states, height, width, output_attentions) + + attention_output = self.dense_output(self_outputs[0]) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFSegformerDWConv(keras.layers.Layer): + def __init__(self, dim: int = 768, **kwargs): + super().__init__(**kwargs) + self.depthwise_convolution = keras.layers.Conv2D( + filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv" + ) + self.dim = dim + + def call(self, hidden_states: tf.Tensor, height: int, 
width: int) -> tf.Tensor: + batch_size = shape_list(hidden_states)[0] + num_channels = shape_list(hidden_states)[-1] + hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels)) + hidden_states = self.depthwise_convolution(hidden_states) + + new_height = shape_list(hidden_states)[1] + new_width = shape_list(hidden_states)[2] + num_channels = shape_list(hidden_states)[3] + hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels)) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "depthwise_convolution", None) is not None: + with tf.name_scope(self.depthwise_convolution.name): + self.depthwise_convolution.build([None, None, None, self.dim]) + + +class TFSegformerMixFFN(keras.layers.Layer): + def __init__( + self, + config: SegformerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + self.dense1 = keras.layers.Dense(hidden_features, name="dense1") + self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv") + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = keras.layers.Dense(out_features, name="dense2") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.hidden_features = hidden_features + self.in_features = in_features + + def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor: + hidden_states = self.dense1(hidden_states) + hidden_states = self.depthwise_convolution(hidden_states, height, width) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense1", None) is not None: + with tf.name_scope(self.dense1.name): + self.dense1.build([None, None, self.in_features]) + if getattr(self, "depthwise_convolution", None) is not None: + with tf.name_scope(self.depthwise_convolution.name): + self.depthwise_convolution.build(None) + if getattr(self, "dense2", None) is not None: + with tf.name_scope(self.dense2.name): + self.dense2.build([None, None, self.hidden_features]) + + +class TFSegformerLayer(keras.layers.Layer): + """This corresponds to the Block class in the original implementation.""" + + def __init__( + self, + config, + hidden_size: int, + num_attention_heads: int, + drop_path: float, + sequence_reduction_ratio: int, + mlp_ratio: int, + **kwargs, + ): + super().__init__(**kwargs) + self.layer_norm_1 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1") + self.attention = TFSegformerAttention( + config, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + sequence_reduction_ratio=sequence_reduction_ratio, + name="attention", + ) + self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else keras.layers.Activation("linear") + self.layer_norm_2 = keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2") + mlp_hidden_size = int(hidden_size * mlp_ratio) + self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, 
hidden_features=mlp_hidden_size, name="mlp") + self.hidden_size = hidden_size + + def call( + self, + hidden_states: tf.Tensor, + height: int, + width: int, + output_attentions: bool = False, + training: bool = False, + ) -> Tuple: + self_attention_outputs = self.attention( + self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention + height, + width, + output_attentions=output_attentions, + training=training, + ) + + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection (with stochastic depth) + attention_output = self.drop_path(attention_output, training=training) + hidden_states = attention_output + hidden_states + mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width) + + # second residual connection (with stochastic depth) + mlp_output = self.drop_path(mlp_output, training=training) + layer_output = mlp_output + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm_1", None) is not None: + with tf.name_scope(self.layer_norm_1.name): + self.layer_norm_1.build([None, None, self.hidden_size]) + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "layer_norm_2", None) is not None: + with tf.name_scope(self.layer_norm_2.name): + self.layer_norm_2.build([None, None, self.hidden_size]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + + +class TFSegformerEncoder(keras.layers.Layer): + def __init__(self, config: SegformerConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + # stochastic depth decay rule + drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))] + + # patch embeddings + embeddings = [] + for i in range(config.num_encoder_blocks): + embeddings.append( + TFSegformerOverlapPatchEmbeddings( + patch_size=config.patch_sizes[i], + stride=config.strides[i], + num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1], + hidden_size=config.hidden_sizes[i], + name=f"patch_embeddings.{i}", + ) + ) + self.embeddings = embeddings + + # Transformer blocks + blocks = [] + cur = 0 + for i in range(config.num_encoder_blocks): + # each block consists of layers + layers = [] + if i != 0: + cur += config.depths[i - 1] + for j in range(config.depths[i]): + layers.append( + TFSegformerLayer( + config, + hidden_size=config.hidden_sizes[i], + num_attention_heads=config.num_attention_heads[i], + drop_path=drop_path_decays[cur + j], + sequence_reduction_ratio=config.sr_ratios[i], + mlp_ratio=config.mlp_ratios[i], + name=f"block.{i}.{j}", + ) + ) + blocks.append(layers) + + self.block = blocks + + # Layer norms + self.layer_norms = [ + keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}") + for i in range(config.num_encoder_blocks) + ] + + def call( + self, + pixel_values: tf.Tensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + training: bool = False, + ) -> Union[Tuple, TFBaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + batch_size = shape_list(pixel_values)[0] + + hidden_states = pixel_values + 
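+        # Stage-by-stage sketch (sizes assume the mit-b0 configuration and a
+        # 512x512 input): strides of 4/2/2/2 shrink the spatial map
+        # 128x128 -> 64x64 -> 32x32 -> 16x16 while the channel width grows
+        # 32 -> 64 -> 160 -> 256.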
for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)): + embedding_layer, block_layer, norm_layer = x + # first, obtain patch embeddings + hidden_states, height, width = embedding_layer(hidden_states) + + # second, send embeddings through blocks + # (each block consists of multiple layers i.e., list of layers) + for i, blk in enumerate(block_layer): + layer_outputs = blk( + hidden_states, + height, + width, + output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + # third, apply layer norm + hidden_states = norm_layer(hidden_states) + + # fourth, optionally reshape back to (batch_size, height, width, num_channels) + if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage): + num_channels = shape_list(hidden_states)[-1] + hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels)) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norms", None) is not None: + for layer, shape in zip(self.layer_norms, self.config.hidden_sizes): + with tf.name_scope(layer.name): + layer.build([None, None, shape]) + if getattr(self, "block", None) is not None: + for block in self.block: + for layer in block: + with tf.name_scope(layer.name): + layer.build(None) + if getattr(self, "embeddings", None) is not None: + for layer in self.embeddings: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFSegformerMainLayer(keras.layers.Layer): + config_class = SegformerConfig + + def __init__(self, config: SegformerConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + # hierarchical Transformer encoder + self.encoder = TFSegformerEncoder(config, name="encoder") + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple, TFBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. 
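+        # e.g. a (2, 3, 512, 512) pixel batch becomes (2, 512, 512, 3) here; the
+        # outputs are transposed back to NCHW below so the TF port matches the
+        # PyTorch output layout.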
+        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+        encoder_outputs = self.encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = encoder_outputs[0]
+        # Change to NCHW output format to have uniformity in the modules
+        sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])
+
+        # Change the other hidden state outputs to NCHW as well
+        if output_hidden_states:
+            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
+
+        if not return_dict:
+            if tf.greater(len(encoder_outputs[1:]), 0):
+                transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0])
+                return (sequence_output,) + (transposed_encoder_outputs,)
+            else:
+                return (sequence_output,) + encoder_outputs[1:]
+
+        return TFBaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+
+
+class TFSegformerPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = SegformerConfig
+    base_model_prefix = "segformer"
+    main_input_name = "pixel_values"
+
+    @property
+    def input_signature(self):
+        return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 512, 512), dtype=tf.float32)}
+
+
+SEGFORMER_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SEGFORMER_INPUTS_DOCSTRING = r"""
+
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, each example of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`SegformerImageProcessor.__call__`] for details.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+    "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
+    SEGFORMER_START_DOCSTRING,
+)
+class TFSegformerModel(TFSegformerPreTrainedModel):
+    def __init__(self, config: SegformerConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.config = config
+
+        # hierarchical Transformer encoder
+        self.segformer = TFSegformerMainLayer(config, name="segformer")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def call(
+        self,
+        pixel_values: tf.Tensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[Tuple, TFBaseModelOutput]:
+        outputs = self.segformer(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "segformer", None) is not None:
+            with tf.name_scope(self.segformer.name):
+                self.segformer.build(None)
+
+
+@add_start_docstrings(
+    """
+    SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
+    states) e.g. for ImageNet.
+ """, + SEGFORMER_START_DOCSTRING, +) +class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: SegformerConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.segformer = TFSegformerMainLayer(config, name="segformer") + + # Classifier head + self.classifier = keras.layers.Dense(config.num_labels, name="classifier") + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TFSequenceClassifierOutput]: + outputs = self.segformer( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # convert last hidden states to (batch_size, height*width, hidden_size) + batch_size = shape_list(sequence_output)[0] + sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1]) + sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1])) + + # global average pooling + sequence_output = tf.reduce_mean(sequence_output, axis=1) + + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "segformer", None) is not None: + with tf.name_scope(self.segformer.name): + self.segformer.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_sizes[-1]]) + + +class TFSegformerMLP(keras.layers.Layer): + """ + Linear Embedding. 
+ """ + + def __init__(self, input_dim: int, config: SegformerConfig, **kwargs): + super().__init__(**kwargs) + self.proj = keras.layers.Dense(config.decoder_hidden_size, name="proj") + self.input_dim = input_dim + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + height = shape_list(hidden_states)[1] + width = shape_list(hidden_states)[2] + hidden_dim = shape_list(hidden_states)[-1] + hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim)) + hidden_states = self.proj(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build([None, None, self.input_dim]) + + +class TFSegformerDecodeHead(TFSegformerPreTrainedModel): + def __init__(self, config: SegformerConfig, **kwargs): + super().__init__(config, **kwargs) + # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size + mlps = [] + for i in range(config.num_encoder_blocks): + mlp = TFSegformerMLP(config=config, input_dim=config.hidden_sizes[i], name=f"linear_c.{i}") + mlps.append(mlp) + self.mlps = mlps + + # the following 3 layers implement the ConvModule of the original implementation + self.linear_fuse = keras.layers.Conv2D( + filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse" + ) + self.batch_norm = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="batch_norm") + self.activation = keras.layers.Activation("relu") + + self.dropout = keras.layers.Dropout(config.classifier_dropout_prob) + self.classifier = keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier") + + self.config = config + + def call(self, encoder_hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + all_hidden_states = () + for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps): + if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3: + height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32)) + height = width = tf.cast(height, tf.int32) + channel_dim = shape_list(encoder_hidden_state)[-1] + encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim)) + + # unify channel dimension + encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1]) + height, width = shape_list(encoder_hidden_state)[1:3] + encoder_hidden_state = mlp(encoder_hidden_state) + channel_dim = shape_list(encoder_hidden_state)[-1] + encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim)) + + # upsample + temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1]) + upsample_resolution = shape_list(temp_state)[1:-1] + encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear") + all_hidden_states += (encoder_hidden_state,) + + hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1)) + hidden_states = self.batch_norm(hidden_states, training=training) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + + # logits of shape (batch_size, height/4, width/4, num_labels) + logits = self.classifier(hidden_states) + + return logits + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "linear_fuse", None) is not None: + with 
tf.name_scope(self.linear_fuse.name): + self.linear_fuse.build( + [None, None, None, self.config.decoder_hidden_size * self.config.num_encoder_blocks] + ) + if getattr(self, "batch_norm", None) is not None: + with tf.name_scope(self.batch_norm.name): + self.batch_norm.build([None, None, None, self.config.decoder_hidden_size]) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, None, self.config.decoder_hidden_size]) + if getattr(self, "mlps", None) is not None: + for layer in self.mlps: + with tf.name_scope(layer.name): + layer.build(None) + + +@add_start_docstrings( + """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""", + SEGFORMER_START_DOCSTRING, +) +class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel): + def __init__(self, config: SegformerConfig, **kwargs): + super().__init__(config, **kwargs) + self.segformer = TFSegformerMainLayer(config, name="segformer") + self.decode_head = TFSegformerDecodeHead(config, name="decode_head") + + def hf_compute_loss(self, logits, labels): + # upsample logits to the images' original size + # `labels` is of shape (batch_size, height, width) + label_interp_shape = shape_list(labels)[1:] + + upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") + # compute weighted loss + loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") + + def masked_loss(real, pred): + unmasked_loss = loss_fct(real, pred) + mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype) + masked_loss = unmasked_loss * mask + # Reduction strategy in the similar spirit with + # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210 + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask) + return tf.reshape(reduced_masked_loss, (1,)) + + return masked_loss(labels, upsampled_logits) + + @unpack_inputs + @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: tf.Tensor, + labels: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TFSemanticSegmenterOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed + (Cross-Entropy). 
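+            Pixels labeled with `config.semantic_loss_ignore_index` are masked out of the loss (see
+            `hf_compute_loss` above).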
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFSegformerForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") + >>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") + + >>> inputs = image_processor(images=image, return_tensors="tf") + >>> outputs = model(**inputs, training=False) + >>> # logits are of shape (batch_size, num_labels, height/4, width/4) + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 150, 128, 128] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if labels is not None and not self.config.num_labels > 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.segformer( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + logits = self.decode_head(encoder_hidden_states) + + loss = None + if labels is not None: + loss = self.hf_compute_loss(logits=logits, labels=labels) + + # make logits of shape (batch_size, num_labels, height, width) to + # keep them consistent across APIs + logits = tf.transpose(logits, perm=[0, 3, 1, 2]) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "segformer", None) is not None: + with tf.name_scope(self.segformer.name): + self.segformer.build(None) + if getattr(self, "decode_head", None) is not None: + with tf.name_scope(self.decode_head.name): + self.decode_head.build(None) + + +__all__ = [ + "TFSegformerDecodeHead", + "TFSegformerForImageClassification", + "TFSegformerForSemanticSegmentation", + "TFSegformerModel", + "TFSegformerPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/seggpt/__init__.py b/docs/transformers/src/transformers/models/seggpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec646f2f81a21d50d50b67177638d275427bc42b --- /dev/null +++ b/docs/transformers/src/transformers/models/seggpt/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_seggpt import * + from .image_processing_seggpt import * + from .modeling_seggpt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/seggpt/configuration_seggpt.py b/docs/transformers/src/transformers/models/seggpt/configuration_seggpt.py new file mode 100644 index 0000000000000000000000000000000000000000..a735281facccc5b74f9924dd48b2363303286029 --- /dev/null +++ b/docs/transformers/src/transformers/models/seggpt/configuration_seggpt.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SegGpt model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SegGptConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SegGptModel`]. It is used to instantiate a SegGPT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the SegGPT + [BAAI/seggpt-vit-large](https://huggingface.co/BAAI/seggpt-vit-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`List[int]`, *optional*, defaults to `[896, 448]`): + The size (resolution) of each image. 
+ patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If unset, defaults to + `hidden_size` * 4. + drop_path_rate (`float`, *optional*, defaults to 0.1): + The drop path rate for the dropout layers. + pretrain_image_size (`int`, *optional*, defaults to 224): + The pretrained size of the absolute position embeddings. + decoder_hidden_size (`int`, *optional*, defaults to 64): + Hidden size for decoder. + use_relative_position_embeddings (`bool`, *optional*, defaults to `True`): + Whether to use relative position embeddings in the attention layers. + merge_index (`int`, *optional*, defaults to 2): + The index of the encoder layer to merge the embeddings. + intermediate_hidden_state_indices (`List[int]`, *optional*, defaults to `[5, 11, 17, 23]`): + The indices of the encoder layers which we store as features for the decoder. + beta (`float`, *optional*, defaults to 0.01): + Regularization factor for SegGptLoss (smooth-l1 loss). + + Example: + + ```python + >>> from transformers import SegGptConfig, SegGptModel + + >>> # Initializing a SegGPT seggpt-vit-large style configuration + >>> configuration = SegGptConfig() + + >>> # Initializing a model (with random weights) from the seggpt-vit-large style configuration + >>> model = SegGptModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "seggpt" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + hidden_act="gelu", + hidden_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=[896, 448], + patch_size=16, + num_channels=3, + qkv_bias=True, + mlp_dim=None, + drop_path_rate=0.1, + pretrain_image_size=224, + decoder_hidden_size=64, + use_relative_position_embeddings=True, + merge_index=2, + intermediate_hidden_state_indices=[5, 11, 17, 23], + beta=0.01, + **kwargs, + ): + super().__init__(**kwargs) + + if merge_index > min(intermediate_hidden_state_indices): + raise ValueError( + f"Merge index must be less than the minimum encoder output index, but got {merge_index=} and {intermediate_hidden_state_indices=}" + ) + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.drop_path_rate = drop_path_rate + self.pretrain_image_size = pretrain_image_size + self.decoder_hidden_size = decoder_hidden_size + self.use_relative_position_embeddings = use_relative_position_embeddings + self.merge_index = merge_index + self.intermediate_hidden_state_indices = intermediate_hidden_state_indices + self.beta = beta + self.mlp_dim = int(hidden_size * 4) if mlp_dim is None else mlp_dim + + +__all__ = ["SegGptConfig"] diff --git a/docs/transformers/src/transformers/models/seggpt/convert_seggpt_to_hf.py b/docs/transformers/src/transformers/models/seggpt/convert_seggpt_to_hf.py new file mode 100644 index 
0000000000000000000000000000000000000000..d67daeab93d899f6792ff09ec5bad5a776b6b16d --- /dev/null +++ b/docs/transformers/src/transformers/models/seggpt/convert_seggpt_to_hf.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert SegGPT checkpoints from the original repository. + +URL: https://github.com/baaivision/Painter/tree/main/SegGPT +""" + +import argparse + +import requests +import torch +from PIL import Image + +from transformers import SegGptConfig, SegGptForImageSegmentation, SegGptImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + + # rename embedding and its parameters + rename_keys.append(("patch_embed.proj.weight", "model.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "model.embeddings.patch_embeddings.projection.bias")) + rename_keys.append(("mask_token", "model.embeddings.mask_token")) + rename_keys.append(("segment_token_x", "model.embeddings.segment_token_input")) + rename_keys.append(("segment_token_y", "model.embeddings.segment_token_prompt")) + rename_keys.append(("type_token_cls", "model.embeddings.type_token_semantic")) + rename_keys.append(("type_token_ins", "model.embeddings.type_token_instance")) + rename_keys.append(("pos_embed", "model.embeddings.position_embeddings")) + + # rename decoder and other + rename_keys.append(("norm.weight", "model.encoder.layernorm.weight")) + rename_keys.append(("norm.bias", "model.encoder.layernorm.bias")) + rename_keys.append(("decoder_embed.weight", "decoder.decoder_embed.weight")) + rename_keys.append(("decoder_embed.bias", "decoder.decoder_embed.bias")) + rename_keys.append(("decoder_pred.0.weight", "decoder.decoder_pred.conv.weight")) + rename_keys.append(("decoder_pred.0.bias", "decoder.decoder_pred.conv.bias")) + rename_keys.append(("decoder_pred.1.weight", "decoder.decoder_pred.layernorm.weight")) + rename_keys.append(("decoder_pred.1.bias", "decoder.decoder_pred.layernorm.bias")) + rename_keys.append(("decoder_pred.3.weight", "decoder.decoder_pred.head.weight")) + rename_keys.append(("decoder_pred.3.bias", "decoder.decoder_pred.head.bias")) + + # rename blocks + for i in range(config.num_hidden_layers): + rename_keys.append((f"blocks.{i}.attn.qkv.weight", f"model.encoder.layers.{i}.attention.qkv.weight")) + rename_keys.append((f"blocks.{i}.attn.qkv.bias", f"model.encoder.layers.{i}.attention.qkv.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"model.encoder.layers.{i}.attention.proj.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"model.encoder.layers.{i}.attention.proj.bias")) + rename_keys.append((f"blocks.{i}.attn.rel_pos_h", f"model.encoder.layers.{i}.attention.rel_pos_h")) + rename_keys.append((f"blocks.{i}.attn.rel_pos_w", 
f"model.encoder.layers.{i}.attention.rel_pos_w")) + + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"model.encoder.layers.{i}.mlp.lin1.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"model.encoder.layers.{i}.mlp.lin1.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"model.encoder.layers.{i}.mlp.lin2.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"model.encoder.layers.{i}.mlp.lin2.bias")) + + rename_keys.append((f"blocks.{i}.norm1.weight", f"model.encoder.layers.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"model.encoder.layers.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"model.encoder.layers.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"model.encoder.layers.{i}.layernorm_after.bias")) + + # fmt: on + + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on spongebob images +def prepare_input(): + image_input_url = ( + "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg" + ) + image_prompt_url = ( + "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg" + ) + mask_prompt_url = ( + "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png" + ) + + image_input = Image.open(requests.get(image_input_url, stream=True).raw) + image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw) + mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw) + + return image_input, image_prompt, mask_prompt + + +@torch.no_grad() +def convert_seggpt_checkpoint(args): + model_name = args.model_name + pytorch_dump_folder_path = args.pytorch_dump_folder_path + verify_logits = args.verify_logits + push_to_hub = args.push_to_hub + + # Define default GroundingDINO configuation + config = SegGptConfig() + + # Load original checkpoint + checkpoint_url = "https://huggingface.co/BAAI/SegGpt/blob/main/seggpt_vit_large.pth" + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"] + + # # Rename keys + new_state_dict = original_state_dict.copy() + rename_keys = create_rename_keys(config) + + for src, dest in rename_keys: + rename_key(new_state_dict, src, dest) + + # Load HF model + model = SegGptForImageSegmentation(config) + model.eval() + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + input_img, prompt_img, prompt_mask = prepare_input() + image_processor = SegGptImageProcessor() + inputs = image_processor(images=input_img, prompt_images=prompt_img, prompt_masks=prompt_mask, return_tensors="pt") + + expected_prompt_pixel_values = torch.tensor( + [ + [[-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965], [-0.6965, -0.6965, -0.6965]], + [[1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583], [1.6583, 1.6583, 1.6583]], + [[2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088], [2.3088, 2.3088, 2.3088]], + ] + ) + + expected_pixel_values = torch.tensor( + [ + [[1.6324, 1.6153, 1.5810], [1.6153, 1.5982, 1.5810], [1.5810, 1.5639, 1.5639]], + [[1.2731, 1.2556, 1.2206], [1.2556, 1.2381, 1.2031], [1.2206, 1.2031, 1.1681]], + [[1.6465, 1.6465, 1.6465], [1.6465, 1.6465, 1.6465], [1.6291, 1.6291, 1.6291]], + ] + ) + + expected_prompt_masks = torch.tensor( + [ + 
[[-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179], [-2.1179, -2.1179, -2.1179]], + [[-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357], [-2.0357, -2.0357, -2.0357]], + [[-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044], [-1.8044, -1.8044, -1.8044]], + ] + ) + + assert torch.allclose(inputs.pixel_values[0, :, :3, :3], expected_pixel_values, atol=1e-4) + assert torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4) + assert torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4) + + torch.manual_seed(2) + outputs = model(**inputs) + print(outputs) + + if verify_logits: + expected_output = torch.tensor( + [ + [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]], + [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]], + [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]], + ] + ) + assert torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_output, atol=1e-4) + print("Looks good!") + else: + print("Converted without verifying logits") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor for {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to hub") + model.push_to_hub(f"EduardoPacheco/{model_name}") + image_processor.push_to_hub(f"EduardoPacheco/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="seggpt-vit-large", + type=str, + choices=["seggpt-vit-large"], + help="Name of the SegGpt model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--verify_logits", + action="store_false", + help="Whether or not to verify the logits against the original implementation.", + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_seggpt_checkpoint(args) diff --git a/docs/transformers/src/transformers/models/seggpt/image_processing_seggpt.py b/docs/transformers/src/transformers/models/seggpt/image_processing_seggpt.py new file mode 100644 index 0000000000000000000000000000000000000000..26c7c1f47acd77f3874c8ccd7fb88a9a6064f877 --- /dev/null +++ b/docs/transformers/src/transformers/models/seggpt/image_processing_seggpt.py @@ -0,0 +1,618 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for SegGPT.""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends + + +if is_torch_available(): + import torch + +if is_vision_available(): + pass + + +logger = logging.get_logger(__name__) + + +# See https://arxiv.org/pdf/2212.02499.pdf at 3.1 Redefining Output Spaces as "Images" - Semantic Segmentation from PAINTER paper +# Taken from https://github.com/Abdullah-Meda/Painter/blob/main/Painter/data/coco_semseg/gen_color_coco_panoptic_segm.py#L31 +def build_palette(num_labels: int) -> List[Tuple[int, int]]: + base = int(num_labels ** (1 / 3)) + 1 + margin = 256 // base + + # we assume that class_idx 0 is the background which is mapped to black + color_list = [(0, 0, 0)] + for location in range(num_labels): + num_seq_r = location // base**2 + num_seq_g = (location % base**2) // base + num_seq_b = location % base + + R = 255 - num_seq_r * margin + G = 255 - num_seq_g * margin + B = 255 - num_seq_b * margin + + color_list.append((R, G, B)) + + return color_list + + +def mask_to_rgb( + mask: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None +) -> np.ndarray: + data_format = data_format if data_format is not None else ChannelDimension.FIRST + + if palette is not None: + height, width = mask.shape + + rgb_mask = np.zeros((3, height, width), dtype=np.uint8) + + classes_in_mask = np.unique(mask) + + for class_idx in classes_in_mask: + rgb_value = palette[class_idx] + class_mask = (mask == class_idx).astype(np.uint8) + class_mask = np.expand_dims(class_mask, axis=-1) + class_rgb_mask = class_mask * np.array(rgb_value) + class_rgb_mask = np.moveaxis(class_rgb_mask, -1, 0) + rgb_mask += class_rgb_mask.astype(np.uint8) + + rgb_mask = np.clip(rgb_mask, 0, 255).astype(np.uint8) + + else: + rgb_mask = np.repeat(mask[None, ...], 3, axis=0) + + return to_channel_dimension_format(rgb_mask, data_format) + + +class SegGptImageProcessor(BaseImageProcessor): + r""" + Constructs a SegGpt image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 448, "width": 448}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. 
Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 448, "width": 448} + size = get_size_dict(size) + self.do_resize = do_resize + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_convert_rgb = do_convert_rgb + + def get_palette(self, num_labels: int) -> List[Tuple[int, int]]: + """Build a palette to map the prompt mask from a single channel to a 3 channel RGB. + + Args: + num_labels (`int`): + Number of classes in the segmentation task (excluding the background). + + Returns: + `List[Tuple[int, int]]`: Palette to map the prompt mask from a single channel to a 3 channel RGB. + """ + return build_palette(num_labels) + + def mask_to_rgb( + self, + image: np.ndarray, + palette: Optional[List[Tuple[int, int]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Converts a segmentation map to RGB format. + + Args: + image (`np.ndarray`): + Segmentation map with dimensions (height, width) where pixel values represent the class index. + palette (`List[Tuple[int, int]]`, *optional*, defaults to `None`): + Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel + dimension. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The mask in RGB format. 
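+
+        Example (an illustrative sketch with a tiny 2x2 segmentation map; class index 0 is treated as
+        background by the palette):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import SegGptImageProcessor
+
+        >>> processor = SegGptImageProcessor()
+        >>> mask = np.array([[0, 1], [1, 2]])  # class indices, 0 is the background
+        >>> palette = processor.get_palette(num_labels=2)
+        >>> rgb = processor.mask_to_rgb(mask, palette=palette)
+        >>> rgb.shape  # channels-first by default
+        (3, 2, 2)
+        ```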
+ """ + return mask_to_rgb(image, palette=palette, data_format=data_format) + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess_step( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + num_labels: Optional[int] = None, + **kwargs, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. 
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has + an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built + to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated + across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format. + num_labels: (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx + channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed + through as is if it is already in RGB format or being duplicated across the channel dimension. 
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + resample = resample if resample is not None else self.resample + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size_dict = get_size_dict(size) + + # If segmentation map is passed we expect 2D images + images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None: + raise ValueError("Size must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None and not do_convert_rgb: + # We assume that all images have the same channel dimension format. 
+                input_data_format = infer_channel_dimension_format(images[0]) + + if do_convert_rgb: + palette = self.get_palette(num_labels) if num_labels is not None else None + # Since this is the input for the next transformations, its format should be the same as the input_data_format + images = [ + self.mask_to_rgb(image=image, palette=palette, data_format=ChannelDimension.FIRST) for image in images + ] + input_data_format = ChannelDimension.FIRST + + if do_resize: + images = [ + self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + return images + + def preprocess( + self, + images: Optional[ImageInput] = None, + prompt_images: Optional[ImageInput] = None, + prompt_masks: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + num_labels: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + prompt_images (`ImageInput`): + Prompt image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + prompt_masks (`ImageInput`): + Prompt mask from the prompt image to preprocess; it provides the `prompt_masks` values of the preprocessed output. + Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of + RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps, + specifying `num_labels` is recommended to build a palette to map the prompt mask from a single channel to + a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the channel + dimension. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BICUBIC`. Only has + an effect if `do_resize` is set to `True`. Doesn't apply to the prompt mask, which is always resized with nearest-neighbor interpolation.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built + to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated + across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format. + num_labels: (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map + with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed + through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated + across the channel dimension. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
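+
+        Example (a minimal sketch using random arrays; in practice `images`/`prompt_images` are typically
+        PIL images and `prompt_masks` is a 2D annotation of the prompt image):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import SegGptImageProcessor
+
+        >>> processor = SegGptImageProcessor()
+        >>> image = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
+        >>> prompt_image = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
+        >>> prompt_mask = np.random.randint(0, 2, (448, 448), dtype=np.uint8)  # one foreground class
+
+        >>> inputs = processor(
+        ...     images=image, prompt_images=prompt_image, prompt_masks=prompt_mask, num_labels=1, return_tensors="pt"
+        ... )
+        >>> list(inputs.keys())
+        ['pixel_values', 'prompt_pixel_values', 'prompt_masks']
+        >>> inputs["pixel_values"].shape
+        torch.Size([1, 3, 448, 448])
+        ```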
+ """ + if all(v is None for v in [images, prompt_images, prompt_masks]): + raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.") + + data = {} + + if images is not None: + images = self._preprocess_step( + images, + is_mask=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=False, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + data["pixel_values"] = images + + if prompt_images is not None: + prompt_images = self._preprocess_step( + prompt_images, + is_mask=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=False, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + data["prompt_pixel_values"] = prompt_images + + if prompt_masks is not None: + prompt_masks = self._preprocess_step( + prompt_masks, + do_resize=do_resize, + size=size, + resample=PILImageResampling.NEAREST, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_convert_rgb=do_convert_rgb, + num_labels=num_labels, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + data["prompt_masks"] = prompt_masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None, num_labels: Optional[int] = None + ): + """ + Converts the output of [`SegGptImageSegmentationOutput`] into segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`SegGptImageSegmentationOutput`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + num_labels (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map prediction masks from RGB values to class + indices. This value should be the same used when preprocessing inputs. + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. 
+ """ + requires_backends(self, ["torch"]) + # batch_size x num_channels x 2*height x width + masks = outputs.pred_masks + + # Predicted mask and prompt are concatenated in the height dimension + # batch_size x num_channels x height x width + masks = masks[:, :, masks.shape[2] // 2 :, :] + + # To unnormalize we need to permute to channel last + # batch_size x height x width x num_channels + std = torch.tensor(self.image_std).to(masks.device) + mean = torch.tensor(self.image_mean).to(masks.device) + + masks = masks.permute(0, 2, 3, 1) * std + mean + + # batch_size x num_channels x height x width + masks = masks.permute(0, 3, 1, 2) + + # Clip to match with palette if specified + masks = torch.clip(masks * 255, 0, 255) + + semantic_segmentation = [] + palette_tensor = None + palette = self.get_palette(num_labels) if num_labels is not None else None + if palette is not None: + palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float) + _, num_channels, _, _ = masks.shape + palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels) + + for idx, mask in enumerate(masks): + if target_sizes is not None: + mask = torch.nn.functional.interpolate( + mask.unsqueeze(0), + size=target_sizes[idx], + mode="nearest", + )[0] + + if num_labels is not None: + channels, height, width = mask.shape + dist = mask.permute(1, 2, 0).view(height, width, 1, channels) + dist = dist - palette_tensor + dist = torch.pow(dist, 2) + dist = torch.sum(dist, dim=-1) + pred = dist.argmin(dim=-1) + + else: + # If no palette is specified SegGpt will try to paint using the mask class idx as RGB + pred = mask.mean(dim=0).int() + + semantic_segmentation.append(pred) + + return semantic_segmentation + + +__all__ = ["SegGptImageProcessor"] diff --git a/docs/transformers/src/transformers/models/seggpt/modeling_seggpt.py b/docs/transformers/src/transformers/models/seggpt/modeling_seggpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e4058b3346724bb1ec9aec18496ebddeaba3fc2a --- /dev/null +++ b/docs/transformers/src/transformers/models/seggpt/modeling_seggpt.py @@ -0,0 +1,1031 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SegGpt model.""" + +import collections.abc +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F + +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_seggpt import SegGptConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SegGptConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "BAAI/seggpt-vit-large" +_EXPECTED_OUTPUT_SHAPE = [3, 896, 448] + + +@dataclass +class SegGptEncoderOutput(ModelOutput): + """ + Output type of [`SegGptEncoderOutput`]. + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape `(batch_size, patch_height, patch_width, hidden_size)`. + attentions (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`): + Tuple of *torch.FloatTensor* (one for each layer) of shape + `(batch_size, num_heads, seq_len, seq_len)`. + intermediate_hidden_states (`Tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set): + Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`. + Each element in the Tuple corresponds to the output of the layer specified in `config.intermediate_hidden_state_indices`. + Additionaly, each feature passes through a LayerNorm. + """ + + last_hidden_state: torch.FloatTensor + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + intermediate_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SegGptImageSegmentationOutput(ModelOutput): + """ + Output type of [`SegGptImageSegmentationOutput`]. + + Args: + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + The loss value. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The predicted masks. + hidden_states (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape `(batch_size, patch_height, patch_width, hidden_size)`. + attentions (`Tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, seq_len, seq_len)`. 
+ """ + + loss: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->SegGpt +class SegGptPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class SegGptEmbeddings(nn.Module): + """ + Construct the embeddings from patch, position embeddings for input and prompt. 
+ """ + + def __init__(self, config: SegGptConfig) -> None: + super().__init__() + + self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + self.segment_token_input = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + self.segment_token_prompt = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + # token for seg types + self.type_token_semantic = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + self.type_token_instance = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size)) + + self.patch_embeddings = SegGptPatchEmbeddings(config) + + num_positions = (config.pretrain_image_size // config.patch_size) ** 2 + 1 + self.position_embeddings = nn.Parameter(torch.randn(1, num_positions, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor: + patch_pos_embed = self.position_embeddings[:, 1:] + num_patches = patch_pos_embed.shape[1] + pretrain_patch_size = torch_int(num_patches**0.5) + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width: + patch_pos_embed = F.interpolate( + patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2), + size=(height, width), + mode="bicubic", + align_corners=False, + ) + + return patch_pos_embed.permute(0, 2, 3, 1) + else: + return patch_pos_embed.reshape(1, height, width, -1) + + def forward( + self, + pixel_values: torch.Tensor, + prompt_pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + embedding_type: Optional[str] = None, + ) -> torch.Tensor: + input_embeddings = self.patch_embeddings(pixel_values) + prompt_embeddings = self.patch_embeddings(prompt_pixel_values) + + batch_size, patch_height, patch_width, _ = input_embeddings.shape + + mask_token = self.mask_token.expand(batch_size, patch_height, patch_width, -1) + # replace the masked visual tokens by mask_token + w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, patch_height, patch_width, 1) + prompt_embeddings = prompt_embeddings * (1 - w) + mask_token * w + + embedding_type = embedding_type if embedding_type is not None else "instance" + + # add positional encoding to each token + pos_embed = self.interpolate_pos_encoding(patch_height, patch_width) + + # add segment token + input_embeddings = input_embeddings + self.segment_token_input + prompt_embeddings = prompt_embeddings + self.segment_token_prompt + + # add position embedding skipping CLS + input_embeddings = input_embeddings + pos_embed + prompt_embeddings = prompt_embeddings + pos_embed + + # add type embedding to each token + if embedding_type == "semantic": + type_embedding = self.type_token_semantic + elif embedding_type == "instance": + type_embedding = self.type_token_instance + else: + raise ValueError(f"Embedding type should be either 'semantic' or 'instance', but got {embedding_type}") + + input_embeddings = input_embeddings + type_embedding + prompt_embeddings = prompt_embeddings + type_embedding + + embeddings = torch.cat((input_embeddings, prompt_embeddings), dim=0) + + return embeddings + + +class SegGptAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) 
else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + + input_size = (image_size[0] // config.patch_size, image_size[1] // config.patch_size) + head_dim = config.hidden_size // config.num_attention_heads + + self.num_attention_heads = config.num_attention_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_relative_position_embeddings = config.use_relative_position_embeddings + if self.use_relative_position_embeddings: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def add_decomposed_rel_pos( + self, + attn: torch.Tensor, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. 
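+
+        Example (illustrative shapes only; `SegGptAttention` is internal and imported from the modeling
+        file directly):
+
+        ```python
+        >>> from transformers import SegGptConfig
+        >>> from transformers.models.seggpt.modeling_seggpt import SegGptAttention
+
+        >>> attention = SegGptAttention(SegGptConfig())  # 56x28 patch grid, head_dim 64
+        >>> # one embedding per relative (query, key) offset along the height axis
+        >>> attention.get_rel_pos(q_size=56, k_size=56, rel_pos=attention.rel_pos_h).shape
+        torch.Size([56, 56, 64])
+        ```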
+ """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) + attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) + return attn + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_relative_position_embeddings: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_attention_heads, height * width, -1) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_attention_heads, height * width, -1) + else: + attn_weights_reshaped = None + + attn_output = (attn_weights @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + return (attn_output, attn_weights_reshaped) + + +# Copied from transformers.models.sam.modeling_sam.SamMLPBlock with SamMLPBlock->SegGptMlp +class SegGptMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->SegGpt +class SegGptDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class SegGptLayer(nn.Module): + def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None: + super().__init__() + self.attention = SegGptAttention(config) + self.mlp = SegGptMlp(config) + self.drop_path = SegGptDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ensemble_cond: int, + feature_ensemble: bool = False, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in SegGpt, layernorm is applied before self-attention + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if feature_ensemble and attention_output.shape[0] // 2 >= ensemble_cond: + prompt, inputs = attention_output.split(attention_output.shape[1] // 2, dim=1) + if ensemble_cond == 2: + num_prompts = attention_output.shape[0] // 2 + inputs = inputs.reshape(2, num_prompts, -1) + inputs = inputs.mean(dim=1, keepdim=True).expand_as(inputs) + inputs = inputs.reshape(*prompt.shape) + else: + inputs = inputs.mean(dim=0, keepdim=True).expand_as(inputs) + attention_output = torch.cat([prompt, inputs], dim=1) + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + residual = hidden_states + + hidden_states = self.layernorm_after(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.drop_path(hidden_states) + + outputs = (hidden_states,) + outputs + + return outputs + + +class SegGptEncoder(nn.Module): + def __init__(self, config: SegGptConfig) -> None: + super().__init__() + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")] + self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)]) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + feature_ensemble: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = 
False, + return_dict: bool = True, + ) -> Union[tuple, SegGptEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + intermediate_hidden_states = [] + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Condition to check if we have the appropriate number of prompts to ensemble + ensemble_cond = 2 if self.config.merge_index > i else 1 + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ensemble_cond, + feature_ensemble, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions) + + hidden_states = layer_outputs[0] + + if i == self.config.merge_index: + hidden_states = ( + hidden_states[: hidden_states.shape[0] // 2] + hidden_states[hidden_states.shape[0] // 2 :] + ) * 0.5 + + if i in self.config.intermediate_hidden_state_indices: + intermediate_hidden_states.append(self.layernorm(hidden_states)) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, intermediate_hidden_states] + if v is not None + ) + return SegGptEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + intermediate_hidden_states=intermediate_hidden_states, + ) + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->SegGpt +class SegGptLayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
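The docstring above describes one normalization at two memory layouts. A minimal, self-contained check (toy shapes, not part of the model) that the manual `channels_first` math used in the implementation below matches `torch.nn.functional.layer_norm` applied on a permuted tensor:

```python
import torch
import torch.nn.functional as F

# Toy channels_first input: (batch, channels, height, width)
x = torch.randn(2, 8, 4, 4)
weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-6

# channels_first: normalize over the channel axis (dim=1) by hand
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
manual = (x - u) / torch.sqrt(s + eps)
manual = weight[:, None, None] * manual + bias[:, None, None]

# channels_last: permute, use the built-in layer_norm, permute back
reference = F.layer_norm(x.permute(0, 2, 3, 1), (8,), weight, bias, eps).permute(0, 3, 1, 2)

print(torch.allclose(manual, reference, atol=1e-5))  # True
```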
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class SegGptDecoderHead(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv2d( + config.decoder_hidden_size, + config.decoder_hidden_size, + kernel_size=3, + padding=1, + ) + self.layernorm = SegGptLayerNorm( + normalized_shape=config.decoder_hidden_size, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.act_fct = ACT2FN[config.hidden_act] + self.head = nn.Conv2d(config.decoder_hidden_size, 3, kernel_size=1, bias=True) # decoder to patch + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.conv(hidden_states) + hidden_states = self.layernorm(hidden_states) + hidden_states = self.act_fct(hidden_states) + hidden_states = self.head(hidden_states) + + return hidden_states + + +class SegGptDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.decoder_embed = nn.Linear( + config.hidden_size * len(config.intermediate_hidden_state_indices), + config.patch_size**2 * config.decoder_hidden_size, + bias=True, + ) + self.decoder_pred = SegGptDecoderHead(config) + self.patch_size = config.patch_size + self.decoder_hidden_size = config.decoder_hidden_size + self.config = config + + def _reshape_hidden_states(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + batch_size, patch_height, patch_width, _ = hidden_states.shape + hidden_states = hidden_states.reshape( + batch_size, patch_height, patch_width, self.patch_size, self.patch_size, self.decoder_hidden_size + ) + hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4) + hidden_states = hidden_states.reshape( + shape=(batch_size, -1, patch_height * self.patch_size, patch_width * self.patch_size) + ) + + return hidden_states + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.decoder_embed(hidden_states) + hidden_states = self._reshape_hidden_states(hidden_states) + hidden_states = self.decoder_pred(hidden_states) + + return hidden_states + + +class SegGptPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
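The `SegGptDecoder._reshape_hidden_states` method above folds each patch embedding back into a `patch_size x patch_size` pixel tile. A shape-only sketch of that permutation, with hypothetical sizes (`patch_size=16`, `decoder_hidden_size=64`):

```python
import torch

batch_size, patch_height, patch_width = 1, 4, 3
patch_size, decoder_hidden_size = 16, 64

# One embedding vector of size patch_size**2 * decoder_hidden_size per patch
hidden = torch.randn(batch_size, patch_height, patch_width, patch_size**2 * decoder_hidden_size)

# Split each patch vector into a (patch_size, patch_size, channels) tile ...
hidden = hidden.reshape(batch_size, patch_height, patch_width, patch_size, patch_size, decoder_hidden_size)
# ... then move channels first and interleave the patch grid with intra-patch pixels
hidden = hidden.permute(0, 5, 1, 3, 2, 4)
hidden = hidden.reshape(batch_size, -1, patch_height * patch_size, patch_width * patch_size)

print(hidden.shape)  # torch.Size([1, 64, 64, 48])
```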
+ """ + + config_class = SegGptConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["SegGptEmbeddings", "SegGptLayer"] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=std).to( + module.weight.dtype + ) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, SegGptAttention): + module.rel_pos_h.data = nn.init.trunc_normal_( + module.rel_pos_h.data.to(torch.float32), + mean=0.0, + std=std, + ).to(module.rel_pos_h.dtype) + + module.rel_pos_w.data = nn.init.trunc_normal_( + module.rel_pos_w.data.to(torch.float32), + mean=0.0, + std=std, + ).to(module.rel_pos_w.dtype) + + elif isinstance(module, SegGptEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=std, + ).to(module.position_embeddings.dtype) + + torch.nn.init.normal_(module.mask_token, std=std) + torch.nn.init.normal_(module.segment_token_input, std=std) + torch.nn.init.normal_(module.segment_token_prompt, std=std) + torch.nn.init.normal_(module.type_token_semantic, std=std) + torch.nn.init.normal_(module.type_token_instance, std=std) + + +SEGGPT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SegGptConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SEGGPT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] + for details. + + prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See + [`SegGptImageProcessor.__call__`] for details. + + prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for + details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + feature_ensemble (`bool`, *optional*): + Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble + if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should + be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image. 
+
+ embedding_type (`str`, *optional*):
+ Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
+ instance or semantic.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare SegGpt Model transformer outputting raw hidden-states without any specific head on top.",
+ SEGGPT_START_DOCSTRING,
+)
+class SegGptModel(SegGptPreTrainedModel):
+ def __init__(self, config: SegGptConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = SegGptEmbeddings(config)
+ self.encoder = SegGptEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> SegGptPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See
+ base class PreTrainedModel.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layers[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(SEGGPT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=SegGptEncoderOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ prompt_pixel_values: torch.Tensor,
+ prompt_masks: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ feature_ensemble: Optional[bool] = None,
+ embedding_type: Optional[str] = None,
+ labels: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SegGptEncoderOutput]:
+ r"""
+ labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
+ Ground truth mask for input images.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import SegGptImageProcessor, SegGptModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
+ >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
+ >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
+
+ >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
+ >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
+ >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
+
+ >>> checkpoint = "BAAI/seggpt-vit-large"
+ >>> model = SegGptModel.from_pretrained(checkpoint)
+ >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
+
+ >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> list(outputs.last_hidden_state.shape)
+ [1, 56, 28, 1024]
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ feature_ensemble = feature_ensemble if feature_ensemble is not None else False
+
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+ pixel_values = pixel_values.to(expected_dtype)
+ prompt_pixel_values = prompt_pixel_values.to(expected_dtype)
+
+ # Prepare inputs
+ pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
+ prompt_pixel_values = (
+ torch.cat((prompt_masks, prompt_masks), dim=2)
+ if labels is None
+ else torch.cat((prompt_masks, labels), dim=2)
+ )
+
+ if bool_masked_pos is None and labels is not None:
+ logger.warning_once(
+ "Labels were provided, but `bool_masked_pos` was not. It will be set to its default value. If you're training the model, make sure to provide a `bool_masked_pos`."
+ )
+
+ # We concatenate on the height axis so SegGPT can handle the pair as a single image; hence we need to mask
+ # the portion of the mask-prompt pixels that corresponds to the prediction, as they don't add any information.
+ # This is only the case for inference. In training, the concatenation of prompt mask and label is masked
+ # and reconstructed together (In-Context Painting).
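As a standalone illustration of the default mask built in the code right below (with a hypothetical `num_patches` of 8): the prompt half of the patches stays visible and the prediction half is masked:

```python
import torch

num_patches = 8  # hypothetical; the real value comes from the patch embeddings
visible = torch.zeros(num_patches // 2, dtype=torch.bool)               # prompt half: not masked
masked = torch.ones(num_patches - num_patches // 2, dtype=torch.bool)   # prediction half: masked
bool_masked_pos = torch.cat([visible, masked]).unsqueeze(0)             # add batch dim

print(bool_masked_pos)
# tensor([[False, False, False, False,  True,  True,  True,  True]])
```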
+ if bool_masked_pos is None:
+ num_patches = self.embeddings.patch_embeddings.num_patches
+ bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
+ bool_masked_pos_ones = torch.ones(
+ num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
+ )
+ bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
+ bool_masked_pos = bool_masked_pos.unsqueeze(0)
+
+ embedding_output = self.embeddings(
+ pixel_values, prompt_pixel_values, embedding_type=embedding_type, bool_masked_pos=bool_masked_pos
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ feature_ensemble=feature_ensemble,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return encoder_outputs
+
+
+def patchify(tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
+ batch_size, num_channels, height, width = tensor.shape
+ patch_height = height // patch_size
+ patch_width = width // patch_size
+
+ tensor = tensor.reshape(shape=(batch_size, num_channels, patch_height, patch_size, patch_width, patch_size))
+ tensor = tensor.permute(0, 2, 4, 3, 5, 1)
+ tensor = tensor.reshape(shape=(batch_size, patch_height * patch_width, patch_size**2 * 3))
+
+ return tensor
+
+
+def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
+ batch_size = tensor.shape[0]
+ patch_size = int((tensor.shape[-1] / 3) ** 0.5)
+ if patch_height * patch_width != tensor.shape[1]:
+ raise ValueError(
+ f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
+ )
+
+ tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
+ tensor = tensor.permute(0, 5, 1, 3, 2, 4)
+ tensor = tensor.reshape(shape=(batch_size, 3, patch_height * patch_size, patch_width * patch_size))
+
+ return tensor
+
+
+class SegGptLoss(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.beta = config.beta
+ self.patch_size = config.patch_size
+
+ def forward(
+ self,
+ prompt_masks: torch.FloatTensor,
+ pred_masks: torch.FloatTensor,
+ labels: torch.FloatTensor,
+ bool_masked_pos: torch.BoolTensor,
+ ):
+ """Computes the smooth L1 loss between the predicted masks and the ground truth masks.
+
+ Args:
+ prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values from mask prompt.
+
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
+ Predicted masks.
+
+ labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Ground truth mask for input images.
+
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+ `torch.FloatTensor`: The mean smooth L1 loss between the predicted masks and the ground truth masks,
+ computed only over the masked patches.
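A toy, self-contained sketch of the reduction implemented in the body below: the patch-level mask is expanded to pixel level (mirroring the `unpatchify` reshape above) and the smooth L1 loss is averaged over masked pixels only. All sizes here are illustrative:

```python
import torch
import torch.nn.functional as F

patch_size, ph, pw = 2, 2, 2                 # toy: a 2x2 grid of 2x2 patches on a 4x4 image
pred = torch.randn(1, 3, ph * patch_size, pw * patch_size)
target = torch.randn_like(pred)

# Patch-level mask -> pixel-level mask with 3 channels (inline unpatchify of the mask)
bool_masked_pos = torch.tensor([[0, 1, 0, 1]], dtype=torch.float32)   # (batch, num_patches)
mask = bool_masked_pos[:, :, None].repeat(1, 1, patch_size**2 * 3)    # (batch, patches, pixels*3)
mask = mask.reshape(1, ph, pw, patch_size, patch_size, 3).permute(0, 5, 1, 3, 2, 4)
mask = mask.reshape(1, 3, ph * patch_size, pw * patch_size)

loss = F.smooth_l1_loss(pred, target, reduction="none", beta=0.01)
loss = (loss * mask).sum() / mask.sum()  # mean over masked pixels only
print(loss)
```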
+ """ + ground_truth = torch.cat((prompt_masks, labels), dim=2) + + mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3) + mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size) + + loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta) + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + + return loss + + +@add_start_docstrings( + "SegGpt model with a decoder on top for one-shot image segmentation.", + SEGGPT_START_DOCSTRING, +) +class SegGptForImageSegmentation(SegGptPreTrainedModel): + def __init__(self, config: SegGptConfig): + super().__init__(config) + self.config = config + + self.model = SegGptModel(config) + self.decoder = SegGptDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SEGGPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SegGptImageSegmentationOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + prompt_pixel_values: torch.Tensor, + prompt_masks: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + feature_ensemble: Optional[bool] = None, + embedding_type: Optional[str] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SegGptImageSegmentationOutput]: + r""" + labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`): + Ground truth mask for input images. + + Returns: + + Examples: + + ```python + >>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation + >>> from PIL import Image + >>> import requests + + >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg" + >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg" + >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png" + + >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw) + >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw) + >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L") + + >>> checkpoint = "BAAI/seggpt-vit-large" + >>> model = SegGptForImageSegmentation.from_pretrained(checkpoint) + >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint) + + >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt") + >>> outputs = model(**inputs) + >>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0] + >>> print(list(result.shape)) + [170, 297] + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is None: + num_patches = self.model.embeddings.patch_embeddings.num_patches + bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device) + 
bool_masked_pos_ones = torch.ones( + num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device + ) + bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones]) + bool_masked_pos = bool_masked_pos.unsqueeze(0) + + outputs = self.model( + pixel_values=pixel_values, + prompt_pixel_values=prompt_pixel_values, + prompt_masks=prompt_masks, + bool_masked_pos=bool_masked_pos, + feature_ensemble=feature_ensemble, + embedding_type=embedding_type, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + intermediate_hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[-1] + intermediate_hidden_states = torch.cat(intermediate_hidden_states, dim=-1) + pred_masks = self.decoder(intermediate_hidden_states) + + loss = None + if labels is not None: + loss_fn = SegGptLoss(self.config) + loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos) + + if not return_dict: + output = (pred_masks,) + if output_hidden_states: + output = output + (outputs[1],) + + if output_attentions: + idx = 2 if output_hidden_states else 1 + output = output + (outputs[idx],) + + if loss is not None: + output = (loss,) + output + return output + + return SegGptImageSegmentationOutput( + loss=loss, + pred_masks=pred_masks, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SegGptModel", "SegGptPreTrainedModel", "SegGptForImageSegmentation"] diff --git a/docs/transformers/src/transformers/models/sew/__init__.py b/docs/transformers/src/transformers/models/sew/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd614f74a68062d842548bdcba4a65ce19fd62c --- /dev/null +++ b/docs/transformers/src/transformers/models/sew/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_sew import * + from .modeling_sew import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/sew/configuration_sew.py b/docs/transformers/src/transformers/models/sew/configuration_sew.py new file mode 100644 index 0000000000000000000000000000000000000000..4365e9c33ded70b8ffc28df98b45ea5b8147d94d --- /dev/null +++ b/docs/transformers/src/transformers/models/sew/configuration_sew.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SEW model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SEWConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SEWModel`]. It is used to instantiate a SEW model
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the SEW
+ [asapp/sew-tiny-100k](https://huggingface.co/asapp/sew-tiny-100k) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32):
+ Vocabulary size of the SEW model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`SEW`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ squeeze_factor (`int`, *optional*, defaults to 2):
+ Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`SEWForCTC`].
+ layerdrop (`float`, *optional*, defaults to 0.1):
+ The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
+ details.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+ The norm to be applied to 1D convolutional layers in feature encoder.
One of `"group"` for group
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+ convolutional layers.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for output of the feature encoder.
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+ *conv_dim*.
+ conv_bias (`bool`, *optional*, defaults to `False`):
+ Whether the 1D convolutional layers have a bias.
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+ embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of 1D convolutional positional embeddings layer.
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+ Recognition](https://arxiv.org/abs/1904.08779).
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+ procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+ irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks`.
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
+ over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
+ vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
+ overlap may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment
+ is True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespective of `mask_feature_prob`. Only relevant if
+ `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`SEWForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`SEWForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`Wav2Vec2ForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the projection before token mean-pooling for classification.
+
+ Example:
+
+ ```python
+ >>> from transformers import SEWConfig, SEWModel
+
+ >>> # Initializing a SEW asapp/sew-tiny-100k style configuration
+ >>> configuration = SEWConfig()
+
+ >>> # Initializing a model (with random weights) from the asapp/sew-tiny-100k style configuration
+ >>> model = SEWModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "sew"
+
+ def __init__(
+ self,
+ vocab_size=32,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ squeeze_factor=2,
+ hidden_act="gelu",
+ hidden_dropout=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ feat_proj_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ feat_extract_norm="group",
+ feat_extract_activation="gelu",
+ conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
+ conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1),
+ conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1),
+ conv_bias=False,
+ num_conv_pos_embeddings=128,
+ num_conv_pos_embedding_groups=16,
+ apply_spec_augment=True,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ ctc_loss_reduction="mean",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=256,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = list(conv_dim)
+ self.conv_stride = list(conv_stride)
+ self.conv_kernel = list(conv_kernel)
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.squeeze_factor = squeeze_factor + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. " + "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, " + f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) " + f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # sequence classification + self.use_weighted_layer_sum = use_weighted_layer_sum + self.classifier_proj_size = classifier_proj_size + + @property + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +__all__ = ["SEWConfig"] diff --git a/docs/transformers/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..df0cae2a3b298923294800295e6a86e500685d21 --- /dev/null +++ b/docs/transformers/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,305 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
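Stepping back to the property that closes `SEWConfig` above: `inputs_to_logits_ratio` is simply the product of the convolutional strides, so with the default `conv_stride` one encoder frame covers 5 * 2**6 = 320 input samples, i.e. 20 ms of 16 kHz audio. A quick check:

```python
import functools
import operator

conv_stride = (5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)  # default SEWConfig strides
ratio = functools.reduce(operator.mul, conv_stride, 1)
print(ratio)          # 320 input samples per output frame
print(16000 / ratio)  # 50.0 frames per second of 16 kHz audio
```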
+"""Convert SEW checkpoint.""" + +import argparse +import json +import os + +import fairseq +import torch +from fairseq.data import Dictionary + +# Register SEW's fairseq modules +from sew_asapp import tasks # noqa: F401 + +from transformers import ( + SEWConfig, + SEWForCTC, + SEWModel, + Wav2Vec2CTCTokenizer, + Wav2Vec2FeatureExtractor, + Wav2Vec2Processor, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.upsample.0": "encoder.upsample.projection", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "layer_norm", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", +} + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert hf_shape == value.shape, ( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights(fairseq_model, hf_model, is_finetuned): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.sew.feature_extractor if is_finetuned else hf_model.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + mapped_key = "sew." 
+ mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key + + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "weight" in name: + weight_type = "weight" + elif "bias" in name: + weight_type = "bias" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, ( + f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was" + " found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." 
+ ) + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +def convert_config(model, is_finetuned): + config = SEWConfig() + if is_finetuned: + fs_config = model.w2v_encoder.w2v_model.cfg + else: + fs_config = model.cfg + + config.conv_bias = fs_config.conv_bias + conv_layers = eval(fs_config.conv_feature_layers) + config.conv_dim = [x[0] for x in conv_layers] + config.conv_kernel = [x[1] for x in conv_layers] + config.conv_stride = [x[2] for x in conv_layers] + config.feat_extract_activation = "gelu" + config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group" + config.final_dropout = 0.0 + config.hidden_act = fs_config.activation_fn.name + config.hidden_size = fs_config.encoder_embed_dim + config.initializer_range = 0.02 + config.intermediate_size = fs_config.encoder_ffn_embed_dim + config.layer_norm_eps = 1e-5 + config.layerdrop = fs_config.encoder_layerdrop + config.num_attention_heads = fs_config.encoder_attention_heads + config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups + config.num_conv_pos_embeddings = fs_config.conv_pos + config.num_feat_extract_layers = len(conv_layers) + config.num_hidden_layers = fs_config.encoder_layers + config.squeeze_factor = fs_config.squeeze_factor + + # take care of any params that are overridden by the Wav2VecCtc model + if is_finetuned: + fs_config = model.cfg + config.final_dropout = fs_config.final_dropout + config.layerdrop = fs_config.layerdrop + config.activation_dropout = fs_config.activation_dropout + config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0 + config.attention_dropout = fs_config.attention_dropout + config.feat_proj_dropout = fs_config.dropout_input + config.hidden_dropout = fs_config.dropout + config.mask_feature_length = fs_config.mask_channel_length + config.mask_feature_prob = fs_config.mask_channel_prob + config.mask_time_length = fs_config.mask_length + config.mask_time_prob = fs_config.mask_prob + + config.feature_extractor_type = "Wav2Vec2FeatureExtractor" + config.tokenizer_class = "Wav2Vec2CTCTokenizer" + + return config + + +@torch.no_grad() +def convert_sew_checkpoint( + checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True +): + """ + Copy/paste/tweak model's weights to transformers design. 
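`convert_config` above recovers the convolutional layout from fairseq's `conv_feature_layers` string of `(dim, kernel, stride)` triples. A sketch of that parsing with a made-up layout string, using `ast.literal_eval` as a safer stand-in for the `eval` call in the script:

```python
import ast

# fairseq stores the feature-extractor layout as a string like this (hypothetical example):
conv_feature_layers = "[(64, 10, 5), (128, 3, 2), (128, 3, 2)]"

# each triple is (dim, kernel, stride)
conv_layers = ast.literal_eval(conv_feature_layers)
conv_dim = [x[0] for x in conv_layers]
conv_kernel = [x[1] for x in conv_layers]
conv_stride = [x[2] for x in conv_layers]
print(conv_dim, conv_kernel, conv_stride)
# [64, 128, 128] [10, 3, 3] [5, 2, 2]
```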
+ """
+
+ if is_finetuned:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+ [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
+ )
+ else:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
+
+ if config_path is not None:
+ config = SEWConfig.from_pretrained(config_path)
+ else:
+ config = convert_config(model[0], is_finetuned)
+ model = model[0].eval()
+
+ return_attention_mask = True if config.feat_extract_norm == "layer" else False
+ feature_extractor = Wav2Vec2FeatureExtractor(
+ feature_size=1,
+ sampling_rate=16000,
+ padding_value=0,
+ do_normalize=True,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if is_finetuned:
+ if dict_path:
+ target_dict = Dictionary.load(dict_path)
+
+ # important change bos & pad token id since CTC symbol is <pad> and
+ # not <s> as in fairseq
+ target_dict.indices[target_dict.bos_word] = target_dict.pad_index
+ target_dict.indices[target_dict.pad_word] = target_dict.bos_index
+ config.bos_token_id = target_dict.pad_index
+ config.pad_token_id = target_dict.bos_index
+ config.eos_token_id = target_dict.eos_index
+ config.vocab_size = len(target_dict.symbols)
+ vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
+ if not os.path.isdir(pytorch_dump_folder_path):
+ logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
+ return
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+ with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
+ json.dump(target_dict.indices, vocab_handle)
+ tokenizer = Wav2Vec2CTCTokenizer(
+ vocab_path,
+ unk_token=target_dict.unk_word,
+ pad_token=target_dict.pad_word,
+ bos_token=target_dict.bos_word,
+ eos_token=target_dict.eos_word,
+ word_delimiter_token="|",
+ do_lower_case=False,
+ )
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ hf_model = SEWForCTC(config)
+ else:
+ hf_model = SEWModel(config)
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+ recursively_load_weights(model, hf_model, is_finetuned)
+
+ hf_model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ parser.add_argument(
+ "--is_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
+ )
+ args = parser.parse_args()
+ convert_sew_checkpoint(
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, args.is_finetuned
+ )
diff --git a/docs/transformers/src/transformers/models/sew/modeling_sew.py b/docs/transformers/src/transformers/models/sew/modeling_sew.py
new file mode 100644
index 0000000000000000000000000000000000000000..572c07e3c9d9d573ab61280ced5fe7fdd2e8a87a
--- /dev/null
+++ b/docs/transformers/src/transformers/models/sew/modeling_sew.py
@@ -0,0 +1,1498 @@
+# coding=utf-8
+# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SEW model.""" + +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_sew import SEWConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 1 + +# General docstring +_CONFIG_FOR_DOC = "SEWConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 512] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = ( + "'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'" +) +_CTC_EXPECTED_LOSS = 0.42 + +# Audio class docstring +_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting" +_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" +_SEQ_CLASS_EXPECTED_LOSS = 9.52 + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
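The core arithmetic of the function documented above, in isolation: the span count is `mask_prob * input_length / mask_length`, rounded probabilistically with a random epsilon and clamped to the bounds described in the implementation below. Toy numbers throughout:

```python
import numpy as np

mask_prob, mask_length, min_masks = 0.05, 10, 2
input_length = 292  # e.g. number of encoder frames

epsilon = np.random.rand(1).item()  # probabilistic rounding, as in the function below
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# spans must fit: at most seq_len // mask_length, and starts need mask_length - 1 headroom
num_masked_span = min(num_masked_span, input_length // mask_length)
num_masked_span = min(num_masked_span, max(input_length - (mask_length - 1), 0))
print(num_masked_span)  # 2 here: 0.05 * 292 / 10 ~ 1.46, rounded either way, floored at min_masks=2
```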
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
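The remainder of the function (continued below) expands the sampled start indices into contiguous spans by broadcasting an offset range, then scatters them into the boolean mask. A toy version of that expansion with made-up start indices:

```python
import numpy as np

mask_length = 3
starts = np.array([[2, 9], [5, 5]])  # (batch, max_num_masked_span) sampled start indices

# (batch, spans, 1) + (1, 1, mask_length) -> every index inside each span
span_idxs = starts[:, :, None] + np.arange(mask_length)[None, None, :]
span_idxs = span_idxs.reshape(starts.shape[0], -1)

mask = np.zeros((2, 12), dtype=bool)
np.put_along_axis(mask, span_idxs, True, axis=-1)  # duplicate indices (row 1) are harmless
print(mask.astype(int))
# [[0 0 1 1 1 0 0 0 0 1 1 1]
#  [0 0 0 0 0 1 1 1 0 0 0 0]]
```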
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEW +class SEWNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEW +class SEWLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEW +class SEWGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + 
self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class SEWPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + stride=config.squeeze_factor, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW +class SEWSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class SEWUpsampling(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) + self.activation = ACT2FN[config.feat_extract_activation] + self.squeeze_factor = config.squeeze_factor + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.activation(hidden_states) + + if self.squeeze_factor > 1: + # transform embedding channels to sequence length + bsz, src_len, src_embed_dim = hidden_states.size() + tgt_len = src_len * self.squeeze_factor + tgt_embed_dim = src_embed_dim // self.squeeze_factor + hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim) + hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEW +class SEWFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [SEWGroupNormConvLayer(config, 
layer_id=0)] + [
+                SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
+            ]
+        elif config.feat_extract_norm == "layer":
+            conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+        else:
+            raise ValueError(
+                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+            )
+        self.conv_layers = nn.ModuleList(conv_layers)
+        self.gradient_checkpointing = False
+        self._requires_grad = True
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def forward(self, input_values):
+        hidden_states = input_values[:, None]
+
+        # make sure hidden_states require grad for gradient_checkpointing
+        if self._requires_grad and self.training:
+            hidden_states.requires_grad = True
+
+        for conv_layer in self.conv_layers:
+            if self._requires_grad and self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    conv_layer.__call__,
+                    hidden_states,
+                )
+            else:
+                hidden_states = conv_layer(hidden_states)
+
+        return hidden_states
+
+
+class SEWFeatureExtractor(SEWFeatureEncoder):
+    def __init__(self, config):
+        super().__init__(config)
+        warnings.warn(
+            f"The class `{self.__class__.__name__}` has been deprecated "
+            "and will be removed in Transformers v5. "
+            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
+            FutureWarning,
+        )
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SEW
+class SEWAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[SEWConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW
+class SEWFlashAttention2(SEWAttention):
+    """
+    SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
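+
+    Note: using this module requires the `flash-attn` package to be installed and the model to be loaded with
+    `attn_implementation="flash_attention_2"`.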
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # SEWFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("SEWFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class SEWSdpaAttention(SEWAttention):
+    # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz)
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+        # but we are fine here as `_shape` does call `.contiguous()`.
Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +SEW_ATTENTION_CLASSES = { + "eager": SEWAttention, + "sdpa": SEWSdpaAttention, + "flash_attention_2": SEWFlashAttention2, +} + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->SEW +class SEWFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->SEW, WAV2VEC2->SEW +class SEWEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = SEW_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = SEWFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SEWEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = SEWPositionalConvEmbedding(config) + self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor) + self.layer_norm = 
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.upsample = SEWUpsampling(config) + self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + if self._use_flash_attention_2: + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + input_lengths = (attention_mask.long()).sum(-1) + # apply pooling formula to get real output_lengths + output_lengths = input_lengths // self.config.squeeze_factor + max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor + attention_ids = ( + torch.arange(0, max_encoder_length, device=output_lengths.device) + .view(1, -1) + .expand(output_lengths.shape[0], -1) + ) + attention_mask = (attention_ids < output_lengths.view(-1, 1)).long() + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + n_input_timesteps = hidden_states.shape[1] + + hidden_states = hidden_states.transpose(1, 2) + position_embeddings = self.pos_conv_embed(hidden_states) + pooled_hidden_states = self.pool(hidden_states) + min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1)) + hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length] + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.upsample(hidden_states) + if 
hidden_states.shape[1] < n_input_timesteps:
+            hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class SEWPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = SEWConfig
+    base_model_prefix = "sew"
+    main_input_name = "input_values"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, SEWPositionalConvEmbedding):
+            nn.init.normal_(
+                module.conv.weight,
+                mean=0,
+                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+            )
+            nn.init.constant_(module.conv.bias, 0)
+        elif isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Conv1d):
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+                        nn.init.kaiming_normal_(module.weight.data)
+                else:
+                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+                        nn.init.kaiming_normal_(module.weight.data)
+            else:
+                nn.init.kaiming_normal_(module.weight.data)
+
+        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+            module.bias.data.zero_()
+
+    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output length indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
+
+SEW_START_DOCSTRING = r"""
+    SEW was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
+    Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing
Pan, Kyu Han, Kilian Q. Weinberger,
+    Yoav Artzi.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`SEWConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SEW_INPUTS_DOCSTRING = r"""
+    Args:
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+            1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare SEW Model transformer outputting raw hidden-states without any specific head on top.",
+    SEW_START_DOCSTRING,
+)
+class SEWModel(SEWPreTrainedModel):
+    def __init__(self, config: SEWConfig):
+        super().__init__(config)
+        self.config = config
+        self.feature_extractor = SEWFeatureEncoder(config)
+        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+
+        self.project_features = config.conv_dim[-1] != config.hidden_size
+        if self.project_features:
+            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
+
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+        self.encoder = SEWEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+    def _mask_hidden_states(
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://arxiv.org/abs/1904.08779).
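+        Masking is applied only when `config.apply_spec_augment` is `True`; time-axis masking additionally requires
+        either explicit `mask_time_indices` or `self.training`, and feature-axis masking requires `self.training`.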
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = self.layer_norm(extract_features) + + if self.project_features: + extract_features = self.feature_projection(extract_features) + hidden_states = self.feature_dropout(extract_features) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + 
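+# A minimal usage sketch for the bare `SEWModel` above (the checkpoint name and the one-second random
+# waveform are illustrative assumptions, not part of this module):
+#
+#     import torch
+#     from transformers import AutoProcessor, SEWModel
+#
+#     processor = AutoProcessor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
+#     model = SEWModel.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
+#     inputs = processor(torch.randn(16000).numpy(), sampling_rate=16000, return_tensors="pt")
+#     with torch.no_grad():
+#         last_hidden_state = model(**inputs).last_hidden_state  # (batch, time, hidden_size)
+
+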
+@add_start_docstrings(
+    """SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+    SEW_START_DOCSTRING,
+)
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
+class SEWForCTC(SEWPreTrainedModel):
+    def __init__(self, config, target_lang: Optional[str] = None):
+        super().__init__(config)
+
+        self.sew = SEWModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        self.target_lang = target_lang
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)` "
+                "or define `vocab_size` of your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def tie_weights(self):
+        """
+        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
+        passing `target_lang=...` to `from_pretrained(...)`.
+
+        This method is **not** supposed to be called by the user and is prone to be changed in the future.
+        """
+
+        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
+        # correctly load adapter layers for SEW so that we do not have to introduce a new API to
+        # [`PreTrainedModel`]. While slightly hacky, SEW never has to tie input and output embeddings, so it is
+        # ok to repurpose this function here.
+        target_lang = self.target_lang
+
+        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
+            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
+        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
+            logger.info("By default `target_lang` is set to 'eng'.")
+        elif target_lang is not None:
+            self.load_adapter(target_lang, force_load=True)
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.sew.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
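+
+        Example (a minimal sketch; the checkpoint name is an illustrative assumption):
+
+        ```python
+        >>> from transformers import SEWForCTC
+
+        >>> model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
+        >>> model.freeze_base_model()
+        >>> any(p.requires_grad for p in model.sew.parameters())
+        False
+        ```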
+ """ + for param in self.sew.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.sew( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB + Keyword Spotting. 
+    """,
+    SEW_START_DOCSTRING,
+)
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
+class SEWForSequenceClassification(SEWPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Sequence classification does not support the use of SEW adapters (config.add_adapter=True)"
+            )
+        self.sew = SEWModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.sew.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.sew.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
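+
+        Example (a minimal sketch; the checkpoint name, label count, and random waveform are illustrative
+        assumptions):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoFeatureExtractor, SEWForSequenceClassification
+
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("asapp/sew-tiny-100k")
+        >>> model = SEWForSequenceClassification.from_pretrained("asapp/sew-tiny-100k", num_labels=2)
+        >>> inputs = feature_extractor(torch.randn(16000).numpy(), sampling_rate=16000, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+        >>> logits.shape
+        torch.Size([1, 2])
+        ```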
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.sew( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/sew_d/__init__.py b/docs/transformers/src/transformers/models/sew_d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7902aa464a9cfd5b465c2448014e218f4540c9d3 --- /dev/null +++ b/docs/transformers/src/transformers/models/sew_d/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_sew_d import * + from .modeling_sew_d import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/sew_d/configuration_sew_d.py b/docs/transformers/src/transformers/models/sew_d/configuration_sew_d.py new file mode 100644 index 0000000000000000000000000000000000000000..f03a8f43c1a276d4ee3ee4e77c1c36ccfa7e075c --- /dev/null +++ b/docs/transformers/src/transformers/models/sew_d/configuration_sew_d.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SEW-D model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SEWDConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SEWDModel`]. It is used to instantiate a SEW-D
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the SEW-D
+    [asapp/sew-d-tiny-100k](https://huggingface.co/asapp/sew-d-tiny-100k) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the SEW-D model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`SEWDModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        squeeze_factor (`int`, *optional*, defaults to 2):
+            Sequence length downsampling factor after the encoder and upsampling factor after the transformer.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        position_buckets (`int`, *optional*, defaults to 256):
+            The maximum size of relative position embeddings.
+        share_att_key (`bool`, *optional*, defaults to `True`):
+            Whether to share attention key with c2p and p2c.
+        relative_attention (`bool`, *optional*, defaults to `True`):
+            Whether to use relative position encoding.
+        pos_att_type (`Tuple[str]`, *optional*, defaults to `("p2c", "c2p")`):
+            The type of relative position attention, a combination of `"p2c"` and `"c2p"`, e.g. `("p2c",)`,
+            `("c2p",)`, or `("p2c", "c2p")`.
+        norm_rel_ebd (`str`, *optional*, defaults to `"layer_norm"`):
+            Whether to use layer norm in the relative embedding (`"layer_norm"` if yes).
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_python"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            Deprecated. Not used by the model and will be removed in a future version.
+ activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + final_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the final projection layer of [`SEWDForCTC`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-7): + The epsilon used by the layer normalization layers in the transformer encoder. + feature_layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization after the feature encoder. + feat_extract_norm (`str`, *optional*, defaults to `"group"`): + The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group + normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D + convolutional layers. + feat_proj_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for output of the feature encoder. + feat_extract_activation (`str, `optional`, defaults to `"gelu"`): + The non-linear activation function (function or string) in the 1D convolutional layers of the feature + extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. + conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`): + A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the + feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers. + conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`): + A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length + of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*. + conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The + length of *conv_kernel* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether the 1D convolutional layers have a bias. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer. + num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer. + apply_spec_augment (`bool`, *optional*, defaults to `True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see + [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. 
+            If reasoning from the probability of each feature vector to be chosen as the start of the vector span to
+            be masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease
+            the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespectively of `mask_time_prob`. Only relevant if
+            `mask_time_prob*len(time_axis)/mask_time_length < mask_time_min_masks`.
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
+            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
+            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
+            overlap may decrease the actual percentage of masked vectors. This is only relevant if
+            `apply_spec_augment is True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespectively of `mask_feature_prob`. Only relevant if
+            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+        diversity_loss_weight (`float`, *optional*, defaults to 0.1):
+            The weight of the codebook diversity loss component.
+        ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`SEWDForCTC`].
+        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+            of [`SEWDForCTC`].
+        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+            instance of [`SEWDForSequenceClassification`].
+        classifier_proj_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the projection before token mean-pooling for classification.
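+
+    With the default `conv_stride`, the feature encoder downsamples the raw waveform by a total factor of
+    5 * 2**6 = 320 (the product of the strides, exposed below via the `inputs_to_logits_ratio` property), i.e. with
+    16 kHz input there is one encoder output frame per 20 ms of audio.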
+ + Example: + + ```python + >>> from transformers import SEWDConfig, SEWDModel + + >>> # Initializing a SEW-D asapp/sew-d-tiny-100k style configuration + >>> configuration = SEWDConfig() + + >>> # Initializing a model (with random weights) from the asapp/sew-d-tiny-100k style configuration + >>> model = SEWDModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "sew-d" + + def __init__( + self, + vocab_size=32, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + squeeze_factor=2, + max_position_embeddings=512, + position_buckets=256, + share_att_key=True, + relative_attention=True, + pos_att_type=("p2c", "c2p"), + norm_rel_ebd="layer_norm", + hidden_act="gelu_python", + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.0, + final_dropout=0.1, + initializer_range=0.02, + layer_norm_eps=1e-7, + feature_layer_norm_eps=1e-5, + feat_extract_norm="group", + feat_extract_activation="gelu", + conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512), + conv_stride=(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1), + conv_kernel=(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + ctc_loss_reduction="mean", + ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.squeeze_factor = squeeze_factor + self.max_position_embeddings = max_position_embeddings + self.position_buckets = position_buckets + self.share_att_key = share_att_key + self.relative_attention = relative_attention + self.norm_rel_ebd = norm_rel_ebd + self.pos_att_type = list(pos_att_type) + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self._hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layer_norm_eps = layer_norm_eps + self.feature_layer_norm_eps = feature_layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. 
" + "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, " + f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) " + f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # sequence classification + self.use_weighted_layer_sum = use_weighted_layer_sum + self.classifier_proj_size = classifier_proj_size + + @property + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. + """ + output = super().to_dict() + output["hidden_dropout"] = output.pop("_hidden_dropout") + return output + + +__all__ = ["SEWDConfig"] diff --git a/docs/transformers/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..1540efa4be171a4741bf304435c66f495acd92e4 --- /dev/null +++ b/docs/transformers/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,317 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert SEW checkpoint.""" + +import argparse +import json +import os + +import fairseq +import torch +from fairseq.data import Dictionary + +# Register SEW's fairseq modules +from sew_asapp import tasks # noqa: F401 + +from transformers import ( + SEWDConfig, + SEWDForCTC, + SEWDModel, + Wav2Vec2CTCTokenizer, + Wav2Vec2FeatureExtractor, + Wav2Vec2Processor, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "attention.self.query_proj": "encoder.encoder.layer.*.attention.self.query_proj", + "attention.self.key_proj": "encoder.encoder.layer.*.attention.self.key_proj", + "attention.self.value_proj": "encoder.encoder.layer.*.attention.self.value_proj", + "attention.output.dense": "encoder.encoder.layer.*.attention.output.dense", + "attention.output.LayerNorm": "encoder.encoder.layer.*.attention.output.LayerNorm", + "intermediate.dense": "encoder.encoder.layer.*.intermediate.dense", + "output.dense": "encoder.encoder.layer.*.output.dense", + "output.LayerNorm": "encoder.encoder.layer.*.output.LayerNorm", + "encoder.encoder.rel_embeddings": "encoder.encoder.rel_embeddings", + "encoder.encoder.LayerNorm": "encoder.encoder.LayerNorm", + "encoder.upsample.0": "encoder.upsample.projection", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "layer_norm", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", +} + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert hf_shape == value.shape, ( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights(fairseq_model, hf_model, is_finetuned): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.sew_d.feature_extractor if is_finetuned else hf_model.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + mapped_key = "sew_d." 
+ mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
+
+                if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
+                    is_used = True
+                    if "*" in mapped_key:
+                        layer_index = name.split(key)[0].split(".")[-2]
+                        if not layer_index.isnumeric():
+                            continue
+                        mapped_key = mapped_key.replace("*", layer_index)
+                    if "weight_g" in name:
+                        weight_type = "weight_g"
+                    elif "weight_v" in name:
+                        weight_type = "weight_v"
+                    elif "weight" in name:
+                        weight_type = "weight"
+                    elif "bias" in name:
+                        weight_type = "bias"
+                    else:
+                        weight_type = None
+                    set_recursively(hf_model, mapped_key, value, name, weight_type)
+                continue
+        if not is_used:
+            unused_weights.append(name)
+
+    logger.warning(f"Unused weights: {unused_weights}")
+
+
+def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
+    name = full_name.split("conv_layers.")[-1]
+    items = name.split(".")
+    layer_id = int(items[0])
+    type_id = int(items[1])
+
+    if type_id == 0:
+        if "bias" in name:
+            assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+                f"{full_name} has size {value.shape}, but"
+                f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+            )
+            feature_extractor.conv_layers[layer_id].conv.bias.data = value
+            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+        elif "weight" in name:
+            assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+                f"{full_name} has size {value.shape}, but"
+                f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+            )
+            feature_extractor.conv_layers[layer_id].conv.weight.data = value
+            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+    elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
+        if "bias" in name:
+            assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+                f"{full_name} has size {value.shape}, but"
+                f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+            )
+            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
+            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+        elif "weight" in name:
+            assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+                f"{full_name} has size {value.shape}, but"
+                f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+ ) + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +def convert_config(model, is_finetuned): + config = SEWDConfig() + if is_finetuned: + fs_config = model.w2v_encoder.w2v_model.cfg + else: + fs_config = model.cfg + + config.conv_bias = fs_config.conv_bias + conv_layers = eval(fs_config.conv_feature_layers) + config.conv_dim = [x[0] for x in conv_layers] + config.conv_kernel = [x[1] for x in conv_layers] + config.conv_stride = [x[2] for x in conv_layers] + config.feat_extract_activation = "gelu" + config.feat_extract_norm = "layer" if fs_config.extractor_mode == "layer_norm" else "group" + config.final_dropout = 0.0 + config.hidden_act = fs_config.activation_fn.name + config.hidden_size = fs_config.encoder_embed_dim + config.initializer_range = 0.02 + config.intermediate_size = fs_config.encoder_ffn_embed_dim + config.layer_norm_eps = 1e-5 + config.layerdrop = fs_config.encoder_layerdrop + config.num_attention_heads = fs_config.encoder_attention_heads + config.num_conv_pos_embedding_groups = fs_config.conv_pos_groups + config.num_conv_pos_embeddings = fs_config.conv_pos + config.num_feat_extract_layers = len(conv_layers) + config.num_hidden_layers = fs_config.encoder_layers + config.squeeze_factor = fs_config.squeeze_factor + # DeBERTa-specific parameters: + config.max_position_embeddings = fs_config.max_position_embeddings + config.position_buckets = fs_config.position_buckets + config.share_att_key = fs_config.share_att_key + config.relative_attention = fs_config.relative_attention + config.position_biased_input = fs_config.position_biased_input + config.pos_att_type = tuple(fs_config.pos_att_type.split("|")) + config.norm_rel_ebd = fs_config.norm_rel_ebd + + # take care of any params that are overridden by the Wav2VecCtc model + if is_finetuned: + fs_config = model.cfg + config.final_dropout = fs_config.final_dropout + config.layerdrop = fs_config.layerdrop + config.activation_dropout = fs_config.activation_dropout + config.apply_spec_augment = fs_config.mask_prob > 0 or fs_config.mask_channel_prob > 0 + config.attention_dropout = fs_config.attention_dropout + config.feat_proj_dropout = fs_config.dropout_input + config.hidden_dropout = fs_config.dropout + config.mask_feature_length = fs_config.mask_channel_length + config.mask_feature_prob = fs_config.mask_channel_prob + config.mask_time_length = fs_config.mask_length + config.mask_time_prob = fs_config.mask_prob + + config.feature_extractor_type = "Wav2Vec2FeatureExtractor" + config.tokenizer_class = "Wav2Vec2CTCTokenizer" + + return config + + +@torch.no_grad() +def convert_sew_checkpoint( + checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True +): + """ + Copy/paste/tweak model's weights to transformers design. 
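+
+    Example invocation (the paths below are placeholders; `--dict_path` is only needed together with
+    `--is_finetuned`):
+
+        python convert_sew_d_original_pytorch_checkpoint_to_pytorch.py \
+            --checkpoint_path /path/to/checkpoint.pt \
+            --dict_path /path/to/dict.ltr.txt \
+            --pytorch_dump_folder_path ./sew-d-converted \
+            --is_finetuned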
+ """ + + if is_finetuned: + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])} + ) + else: + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path]) + + if config_path is not None: + config = SEWDConfig.from_pretrained(config_path) + else: + config = convert_config(model[0], is_finetuned) + model = model[0].eval() + + return_attention_mask = True if config.feat_extract_norm == "layer" else False + feature_extractor = Wav2Vec2FeatureExtractor( + feature_size=1, + sampling_rate=16000, + padding_value=0, + do_normalize=True, + return_attention_mask=return_attention_mask, + ) + + if is_finetuned: + if dict_path: + target_dict = Dictionary.load(dict_path) + + # important change bos & pad token id since CTC symbol is and + # not as in fairseq + target_dict.indices[target_dict.bos_word] = target_dict.pad_index + target_dict.indices[target_dict.pad_word] = target_dict.bos_index + config.bos_token_id = target_dict.pad_index + config.pad_token_id = target_dict.bos_index + config.eos_token_id = target_dict.eos_index + config.vocab_size = len(target_dict.symbols) + vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json") + if not os.path.isdir(pytorch_dump_folder_path): + logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path)) + return + os.makedirs(pytorch_dump_folder_path, exist_ok=True) + with open(vocab_path, "w", encoding="utf-8") as vocab_handle: + json.dump(target_dict.indices, vocab_handle) + tokenizer = Wav2Vec2CTCTokenizer( + vocab_path, + unk_token=target_dict.unk_word, + pad_token=target_dict.pad_word, + bos_token=target_dict.bos_word, + eos_token=target_dict.eos_word, + word_delimiter_token="|", + do_lower_case=False, + ) + processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + processor.save_pretrained(pytorch_dump_folder_path) + + hf_model = SEWDForCTC(config) + else: + hf_model = SEWDModel(config) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + recursively_load_weights(model, hf_model, is_finetuned) + + hf_model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--is_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not" + ) + args = parser.parse_args() + convert_sew_checkpoint( + args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, args.is_finetuned + ) diff --git a/docs/transformers/src/transformers/models/sew_d/modeling_sew_d.py b/docs/transformers/src/transformers/models/sew_d/modeling_sew_d.py new file mode 100644 index 0000000000000000000000000000000000000000..96c587fbb2eb4254c238cc7ee88a3b1126381c5a --- /dev/null +++ b/docs/transformers/src/transformers/models/sew_d/modeling_sew_d.py @@ -0,0 +1,1748 @@ +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SEW model.""" + +import math +import warnings +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import softmax_backward_data +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_sew_d import SEWDConfig + + +logger = logging.get_logger(__name__) + +_HIDDEN_STATES_START_POSITION = 1 + + +# General docstring +_CONFIG_FOR_DOC = "SEWDConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 384] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 0.21 + +# Audio class docstring +_SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting" +_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" +_SEQ_CLASS_EXPECTED_LOSS = 3.16 + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
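+
+    For example, with `shape=(2, 100)`, `mask_prob=0.05` and `mask_length=10`, roughly
+    `0.05 * 100 / 10 = 0.5` spans are drawn per sequence, so the probabilistic rounding below yields 0 or 1 span
+    of 10 masked time steps per example (before `min_masks` is applied).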
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+def make_log_bucket_position(relative_pos, bucket_size, max_position):
+    sign = torch.sign(relative_pos)
+    mid = bucket_size // 2
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
+    return bucket_pos
+
+
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
+    """
+    Build relative position according to the query and key
+
+    We assume the absolute position of query \\(P_q\\) ranges over (0, query_size) and the absolute position of key
+    \\(P_k\\) ranges over (0, key_size). The relative positions from query to key are \\(R_{q \\rightarrow k} = P_q -
+    P_k\\)
+
+    Args:
+        query_size (int): the length of query
+        key_size (int): the length of key
+        bucket_size (int): the size of position bucket
+        max_position (int): the maximum allowed absolute position
+        device (`torch.device`): the device on which tensors will be created.
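+
+    For example, without bucketing, `build_relative_position(3, 3)` returns
+    `[[[0, -1, -2], [1, 0, -1], [2, 1, 0]]]`, i.e. entry `[0, i, j]` equals `i - j`.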
+ + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + """ + + q_ids = torch.arange(0, query_size, device=device) + k_ids = torch.arange(0, key_size, device=device) + rel_pos_ids = q_ids[:, None] - k_ids[None, :] + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = rel_pos_ids.to(torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEWD +class SEWDNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEWD +class SEWDLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from 
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEWD +class SEWDGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.sew.modeling_sew.SEWPositionalConvEmbedding with SEW->SEWD +class SEWDPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + stride=config.squeeze_factor, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW +class SEWDSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.sew.modeling_sew.SEWUpsampling with SEW->SEWD +class SEWDUpsampling(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) + self.activation = ACT2FN[config.feat_extract_activation] + self.squeeze_factor = config.squeeze_factor + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.activation(hidden_states) + + if self.squeeze_factor > 1: + # transform embedding channels to sequence length + bsz, src_len, src_embed_dim = hidden_states.size() + tgt_len = src_len * self.squeeze_factor + tgt_embed_dim = src_embed_dim // 
self.squeeze_factor
+            hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
+            hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
+
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEWD
+class SEWDFeatureEncoder(nn.Module):
+    """Construct the features from raw audio waveform"""
+
+    def __init__(self, config):
+        super().__init__()
+
+        if config.feat_extract_norm == "group":
+            conv_layers = [SEWDGroupNormConvLayer(config, layer_id=0)] + [
+                SEWDNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
+            ]
+        elif config.feat_extract_norm == "layer":
+            conv_layers = [SEWDLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+        else:
+            raise ValueError(
+                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+            )
+        self.conv_layers = nn.ModuleList(conv_layers)
+        self.gradient_checkpointing = False
+        self._requires_grad = True
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def forward(self, input_values):
+        hidden_states = input_values[:, None]
+
+        # make sure hidden_states require grad for gradient_checkpointing
+        if self._requires_grad and self.training:
+            hidden_states.requires_grad = True
+
+        for conv_layer in self.conv_layers:
+            if self._requires_grad and self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    conv_layer.__call__,
+                    hidden_states,
+                )
+            else:
+                hidden_states = conv_layer(hidden_states)
+
+        return hidden_states
+
+
+class SEWDFeatureExtractor(SEWDFeatureEncoder):
+    def __init__(self, config):
+        super().__init__(config)
+        warnings.warn(
+            f"The class `{self.__class__.__name__}` has been deprecated "
+            "and will be removed in Transformers v5. "
+            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
+            FutureWarning,
+        )
+
+
+class ContextPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
+        self.dropout = StableDropout(config.pooler_dropout)
+        self.config = config
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+
+        context_token = hidden_states[:, 0]
+        context_token = self.dropout(context_token)
+        pooled_output = self.dense(context_token)
+        pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
+        return pooled_output
+
+    @property
+    def output_dim(self):
+        return self.config.hidden_size
+
+
+class XSoftmax(torch.autograd.Function):
+    """
+    Masked Softmax which is optimized for saving memory
+
+    Args:
+        input (`torch.tensor`): The input tensor that will apply softmax.
+        mask (`torch.IntTensor`):
+            The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
+ dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(ctx, input, mask, dim): + ctx.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, ctx.dim) + output.masked_fill_(rmask, 0) + ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx, grad_output): + (output,) = ctx.saved_tensors + inputGrad = softmax_backward_data(ctx, grad_output, output, ctx.dim, output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"]) + r_mask = g.op( + "Cast", + g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx["Bool"], + ) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) + output = softmax(g, output, dim) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool))) + + +class DropoutContext: + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + from torch.onnx import symbolic_opset12 + + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. 
+ # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return symbolic_opset12.dropout(g, input, dropout_p, train) + + +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +class SEWDSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.activation_dropout) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. 
The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.activation_dropout) + + if not self.share_att_key: + if "c2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_dropout) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.BoolTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, *optional*): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, *optional*): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. 
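+
+        Note that the attention scores are divided by `sqrt(head_dim * scale_factor)`, where `scale_factor` is 1
+        plus the number of enabled disentangled components (`"c2p"` and/or `"p2c"`), so that the content-to-content
+        and position-aware score terms stay on a comparable scale.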
+ + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + device=query_layer.device, + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long) + + rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale.to(dtype=c2p_att.dtype) + + # position->content + if "p2c" in self.pos_att_type: + scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + device=query_layer.device, + ) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + score += p2c_att / scale.to(dtype=p2c_att.dtype) + + return score + + +class SEWDAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = SEWDSelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->SEWD +class SEWDIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = 
ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SEWDOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = nn.Dropout(config.activation_dropout) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class SEWDLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = SEWDAttention(config) + self.intermediate = SEWDIntermediate(config) + self.output = SEWDOutput(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class SEWDTransformerEncoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([SEWDLayer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + 
pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, + hidden_states.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + device=hidden_states.device, + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = attention_mask.sum(-2) > 0 + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if self.gradient_checkpointing and self.training: + output_states = self._gradient_checkpointing_func( + layer_module.__call__, + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + output_attentions, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] 
if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )
+
+
+class SEWDEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.pos_conv_embed = SEWDPositionalConvEmbedding(config)
+        self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
+        self.encoder = SEWDTransformerEncoder(config)
+        self.upsample = SEWDUpsampling(config)
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (hidden_states.shape[0], max_encoder_length), dtype=torch.long, device=hidden_states.device
+            )
+        else:
+            # make sure padded tokens output 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask.bool()] = 0.0
+
+            input_lengths = (attention_mask.long()).sum(-1)
+            # apply the pooling formula to get the real output lengths
+            output_lengths = input_lengths // self.config.squeeze_factor
+            attention_ids = (
+                torch.arange(0, max_encoder_length, device=output_lengths.device)
+                .view(1, -1)
+                .expand(output_lengths.shape[0], -1)
+            )
+            attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
+
+        n_input_timesteps = hidden_states.shape[1]
+
+        hidden_states = hidden_states.transpose(1, 2)
+        position_embeddings = self.pos_conv_embed(hidden_states)
+        pooled_hidden_states = self.pool(hidden_states)
+        min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
+        hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
+        hidden_states = hidden_states.transpose(1, 2)
+
+        encoder_outputs = self.encoder(hidden_states, attention_mask, output_hidden_states, output_attentions)
+
+        hidden_states = self.upsample(encoder_outputs.last_hidden_state)
+        if hidden_states.shape[1] < n_input_timesteps:
+            hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
+
+        if not return_dict:
+            return tuple(
+                v for v in [hidden_states, encoder_outputs.hidden_states, encoder_outputs.attentions] if v is not None
+            )
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class SEWDPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models. 
+    """
+
+    config_class = SEWDConfig
+    base_model_prefix = "sew-d"
+    main_input_name = "input_values"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, SEWDPositionalConvEmbedding):
+            nn.init.normal_(
+                module.conv.weight,
+                mean=0,
+                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+            )
+            nn.init.constant_(module.conv.bias, 0)
+        elif isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Conv1d):
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+                        nn.init.kaiming_normal_(module.weight.data)
+                else:
+                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+                        nn.init.kaiming_normal_(module.weight.data)
+            else:
+                nn.init.kaiming_normal_(module.weight.data)
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+            module.bias.data.zero_()
+
+    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output length indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
+
+SEWD_START_DOCSTRING = r"""
+    SEW-D was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
+    Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
+    Yoav Artzi.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving etc.).
+
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`SEWDConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SEWD_INPUTS_DOCSTRING = r"""
+    Args:
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+            1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
+    SEWD_START_DOCSTRING,
+)
+# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps
+class SEWDModel(SEWDPreTrainedModel):
+    def __init__(self, config: SEWDConfig):
+        super().__init__(config)
+        self.config = config
+        self.feature_extractor = SEWDFeatureEncoder(config)
+        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps)
+
+        self.project_features = config.conv_dim[-1] != config.hidden_size
+        if self.project_features:
+            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
+
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+        self.encoder = SEWDEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+    def _mask_hidden_states(
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along the time axis and/or the feature axis according to
+        [SpecAugment](https://arxiv.org/abs/1904.08779). 
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = self.layer_norm(extract_features) + + if self.project_features: + extract_features = self.feature_projection(extract_features) + hidden_states = self.feature_dropout(extract_features) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + 
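+# A minimal usage sketch for `SEWDModel` (illustrative only; it assumes the public
+# "asapp/sew-d-tiny-100k" checkpoint and a one-second silent waveform at 16 kHz):
+#
+#     import torch
+#     from transformers import AutoFeatureExtractor, SEWDModel
+#
+#     feature_extractor = AutoFeatureExtractor.from_pretrained("asapp/sew-d-tiny-100k")
+#     model = SEWDModel.from_pretrained("asapp/sew-d-tiny-100k")
+#     inputs = feature_extractor(torch.zeros(16_000).numpy(), sampling_rate=16_000, return_tensors="pt")
+#     with torch.no_grad():
+#         outputs = model(**inputs)
+#     # `last_hidden_state` is (batch, frames, hidden_size); the upsampling layer
+#     # restores the frame count reduced by the encoder's `squeeze_factor` pooling.
+#     print(outputs.last_hidden_state.shape)
+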
+
+
+@add_start_docstrings(
+    """SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+    SEWD_START_DOCSTRING,
+)
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
+class SEWDForCTC(SEWDPreTrainedModel):
+    def __init__(self, config, target_lang: Optional[str] = None):
+        super().__init__(config)
+
+        self.sew_d = SEWDModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        self.target_lang = target_lang
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `SEWDForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
+                "or define `vocab_size` in your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def tie_weights(self):
+        """
+        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
+        passing `target_lang=...` to `from_pretrained(...)`.
+
+        This method is **not** supposed to be called by the user and is prone to be changed in the future.
+        """
+
+        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
+        # correctly load adapter layers for SEWD so that we do not have to introduce a new API to
+        # [`PreTrainedModel`]. While slightly hacky, SEWD never has to tie input and output embeddings, so it is
+        # ok to repurpose this function here.
+        target_lang = self.target_lang
+
+        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
+            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
+        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
+            logger.info("By default `target_lang` is set to 'eng'.")
+        elif target_lang is not None:
+            self.load_adapter(target_lang, force_load=True)
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.sew_d.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated. 
+        """
+        for param in self.sew_d.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+            config.vocab_size - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be smaller than the vocabulary size: {self.config.vocab_size}")
+
+        outputs = self.sew_d(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # retrieve loss input_lengths from attention_mask
+            attention_mask = (
+                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+            )
+            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+            # assuming that padded tokens are filled with -100
+            # when not being attended to
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1)
+            flattened_targets = labels.masked_select(labels_mask)
+
+            # ctc_loss doesn't support fp16
+            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+            with torch.backends.cudnn.flags(enabled=False):
+                loss = nn.functional.ctc_loss(
+                    log_probs,
+                    flattened_targets,
+                    input_lengths,
+                    target_lengths,
+                    blank=self.config.pad_token_id,
+                    reduction=self.config.ctc_loss_reduction,
+                    zero_infinity=self.config.ctc_zero_infinity,
+                )
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
+
+
+@add_start_docstrings(
+    """
+    SEWD Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
+    Keyword Spotting. 
+    """,
+    SEWD_START_DOCSTRING,
+)
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV2VEC2->SEWD
+class SEWDForSequenceClassification(SEWDPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Sequence classification does not support the use of SEWD adapters (config.add_adapter=True)"
+            )
+        self.sew_d = SEWDModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.sew_d.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.sew_d.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.sew_d( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["SEWDForCTC", "SEWDForSequenceClassification", "SEWDModel", "SEWDPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/shieldgemma2/__init__.py b/docs/transformers/src/transformers/models/shieldgemma2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3eaa894027352f1dd928e853bec4eef131acd8e4 --- /dev/null +++ b/docs/transformers/src/transformers/models/shieldgemma2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
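+
+# Note: this package follows the library's lazy-import pattern. Under `TYPE_CHECKING`,
+# the star imports below expose the full symbol table to type checkers; at runtime the
+# module is replaced with a `_LazyModule` that resolves attributes (and their heavy
+# framework imports) only on first access.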
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_shieldgemma2 import * + from .modeling_shieldgemma2 import * + from .processing_shieldgemma2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py b/docs/transformers/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py new file mode 100644 index 0000000000000000000000000000000000000000..acb53924b2e5e2a2cc88c120eb6d8ccd416a9b67 --- /dev/null +++ b/docs/transformers/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class ShieldGemma2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ShieldGemma2ForImageClassification`]. It is used to instantiate an + ShieldGemma2ForImageClassification according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the shieldgemma-2-4b-it. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[ShieldGemma2TextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+
+
+    Example:
+
+    ```python
+    >>> from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Config, SiglipVisionConfig, Gemma3TextConfig
+
+    >>> # Initializing a SigLIP-like vision config
+    >>> vision_config = SiglipVisionConfig()
+
+    >>> # Initializing a Gemma 3 text config for the language backbone
+    >>> text_config = Gemma3TextConfig()
+
+    >>> # Initializing a ShieldGemma 2 configuration in the shieldgemma-2-4b-it style
+    >>> configuration = ShieldGemma2Config(text_config=text_config, vision_config=vision_config)
+
+    >>> # Initializing a model from that configuration
+    >>> model = ShieldGemma2ForImageClassification(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "shieldgemma2"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "boi_token_id": "boi_token_index",
+        "eoi_token_id": "eoi_token_index",
+    }
+    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        mm_tokens_per_image: int = 256,
+        boi_token_index: int = 255_999,
+        eoi_token_index: int = 256_000,
+        image_token_index: int = 262_144,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            vision_config["model_type"] = (
+                vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
+            )
+            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+        elif vision_config is None:
+            vision_config = CONFIG_MAPPING["siglip_vision_model"]()
+
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma3_text"
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = CONFIG_MAPPING["gemma3_text"]()
+
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.mm_tokens_per_image = mm_tokens_per_image
+        self.boi_token_index = boi_token_index
+        self.eoi_token_index = eoi_token_index
+        self.image_token_index = image_token_index
+        self.initializer_range = initializer_range
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["ShieldGemma2Config"]
diff --git a/docs/transformers/src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py b/docs/transformers/src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ca46eb29e2de38ee06e339fe91c0ef22caf2f1d
--- /dev/null
+++ b/docs/transformers/src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py
@@ -0,0 +1,470 @@
+r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. 
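+
+The vision tower weights are read from a Gemma 3 Orbax checkpoint and the language model
+weights from a ShieldGemma 2 Orbax checkpoint; the two are merged into a single Hugging
+Face safetensors checkpoint. Example invocation (paths are illustrative):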
+ +python -m transformers.models.shieldgemma2.convert_shieldgemma2_weights_orbax_to_hf \ + --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ + --checkpoint_path_gemma="$HOME/gemma3/gemma3_4b_pt_orbax/" \ + --checkpoint_path_shieldgemma="$HOME/shieldgemma2/shieldgemma-2_4b_orbax/" \ + --output_path="$HOME/shieldgemma2/shieldgemma2_4b_pt_safetensors/" \ + --precision='bfloat16' +""" + +import dataclasses +from collections.abc import Iterator, Mapping, Sequence +from typing import Any + +import accelerate +import numpy as np +import torch +import tree +from absl import app, flags, logging +from orbax import checkpoint as obc + +from ...image_utils import PILImageResampling +from ..gemma import GemmaTokenizerFast +from ..gemma3 import Gemma3ImageProcessor, Gemma3TextConfig +from ..siglip import SiglipVisionConfig +from . import ShieldGemma2Config, ShieldGemma2ForImageClassification, ShieldGemma2Processor + + +# ==== Internal Constants and Classes ==== + +_CHAT_TEMPLATE = """{{ bos_token }} +{%- for message in messages -%} + {%- if message['content'] is iterable -%} + {%- for item in message['content'] -%} + {{ 'user\n' }} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ 'You are a policy expert trying to help determine whether a generated or real image is in' }} + {{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}} + {{ ' language used in the safety policy and should not include any interpretation outside of the' }} + {{ ' safety policy.\n\n' }} + {{ '\n' }} + {{ item['text'] + '\n' }} + {{ '\n\n' }} + {{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }} + {{ '\n' }} + {%- endif -%} + {%- endfor -%} + {{'model\n'}} + {%- else -%} + {{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }} + {%- endif -%} +{%- endfor -%} +""" + +_DTYPES = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + +_SHIELDGEMMA2_POLICIES: Mapping[str, str] = { + "dangerous": ( + "No Dangerous Content: The image shall not contain content that" + " facilitates or encourages activities that could cause real-world harm" + " (e.g., building firearms and explosive devices, promotion of" + " terrorism, instructions for suicide)." + ), + "sexual": ( + "No Sexually Explicit content: The image shall not contain content that" + " depicts explicit or graphic sexual acts (e.g., pornography, erotic" + " nudity, depictions of rape or sexual assault)." + ), + "violence": ( + "No Violence/Gore content: The image shall not contain content that" + " depicts shocking, sensational, or gratuitous violence (e.g.," + " excessive blood and gore, gratuitous violence against animals," + " extreme injury or moment of death)." 
+ ), +} + +_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder" +_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK) +_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm" + +_TRANSFORMER_DECODER_BLOCK = "transformer/layer_" +_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) +_TRANSFORMER_EMBEDDER = "transformer/embedder" +_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" +_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) + +# ==== Flags ==== + +_GEMMA_CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path_gemma", + default=None, + help="Path to the Orbax checkpoint containing the vision weights.", + required=True, +) + +_SHIELDGEMMA_CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path_shieldgemma", + default=None, + help="Path to the Orbax checkpoint containing the language model weights.", + required=True, +) + +OUTPUT_PATH = flags.DEFINE_string( + name="output_path", + default=None, + help="Path to store the HF checkpoint.", + required=True, +) + +PRECISION = flags.DEFINE_enum( + name="precision", + default=None, + help="The floating point precision (aka dtype) of the model.", + enum_values=set(_DTYPES.keys()), + required=True, +) + +TOKENIZER_PATH = flags.DEFINE_string( + name="tokenizer_path", + default=None, + help="Path to the SentencePiece model file.", + required=True, +) + + +def convert_siglip_weight( + config: SiglipVisionConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> tuple[str, np.ndarray]: + path, prop = paths + normalized_path: str = "" + updated_weights: np.ndarray = None + + if path == _SIGLIP_BASE: + normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight" + updated_weights = weights.reshape(-1, config.hidden_size) + elif path == _SIGLIP_EMBEDDING: + if prop == "kernel": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight" + updated_weights = weights.transpose(3, 2, 0, 1) + elif prop == "bias": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK): + encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:] + next_path_seperator_idx = encoder_block_path.find("/") + layer_idx = encoder_block_path[:next_path_seperator_idx] + encoder_block_path = encoder_block_path[next_path_seperator_idx:] + normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}" + + if encoder_block_path.startswith("/LayerNorm"): + normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2" + + if prop == "scale": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. 
Should be `bias` or `scale`.")
+        elif encoder_block_path.startswith("/MlpBlock_0"):
+            normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2"
+
+            if prop == "kernel":
+                normalized_path += ".weight"
+                updated_weights = weights.transpose()
+            elif prop == "bias":
+                normalized_path += ".bias"
+                updated_weights = weights
+            else:
+                raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
+        elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"):
+            if encoder_block_path.endswith("/key"):
+                normalized_path += ".self_attn.k_proj"
+            elif encoder_block_path.endswith("/out"):
+                normalized_path += ".self_attn.out_proj"
+            elif encoder_block_path.endswith("/query"):
+                normalized_path += ".self_attn.q_proj"
+            elif encoder_block_path.endswith("/value"):
+                normalized_path += ".self_attn.v_proj"
+            else:
+                raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.")
+
+            if prop == "bias":
+                normalized_path += ".bias"
+                updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1)
+            elif prop == "kernel":
+                normalized_path += ".weight"
+                updated_weights = weights.reshape(-1, config.hidden_size).transpose()
+            else:
+                raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
+        else:
+            raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.")
+    elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM:
+        if prop == "scale":
+            normalized_path = "vision_tower.vision_model.post_layernorm.weight"
+            updated_weights = weights.transpose()
+        elif prop == "bias":
+            normalized_path = "vision_tower.vision_model.post_layernorm.bias"
+            updated_weights = weights
+        else:
+            raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
+    else:
+        raise ValueError(f"Unexpected path `{path}`.")
+
+    if "vision" in normalized_path:
+        logging.info("Converted SigLIP path: %s", normalized_path)
+    return normalized_path, updated_weights
+
+
+def convert_transformer_weights(
+    config: Gemma3TextConfig,
+    paths: Sequence[str],
+    weights: np.ndarray,
+) -> Iterator[tuple[str, np.ndarray]]:
+    path, prop = paths
+
+    if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
+        path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
+
+    converted_paths: list[str] = []
+    converted_weights: list[Any] = []
+
+    attn_head_dim = config.num_attention_heads * config.head_dim
+    kv_head_dim = config.num_key_value_heads * config.head_dim
+
+    if path == _TRANSFORMER_EMBEDDER:
+        if prop == "input_embedding":
+            # Tied to language_model.lm_head.weight, assigned at the end. 
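+            # The checkpoint vocabulary lacks the 64 extra multimodal special tokens, so
+            # the embedding matrix is extended below by sampling new rows from a Gaussian
+            # fitted to the existing embeddings (mean `mu`, damped covariance `1e-5 * sigma`).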
+ converted_paths = ["language_model.model.embed_tokens.weight"] + # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama + pre_expansion_embeddings = weights + mu = np.mean(pre_expansion_embeddings, axis=0) + sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True) + new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64) + weights = np.vstack([pre_expansion_embeddings, new_embeddings]) + converted_weights = [weights] + else: + raise ValueError(f"Unexpected member, {prop}, in Embedder.") + elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): + if path.endswith("/mm_input_projection"): + converted_paths = ["multi_modal_projector.mm_input_projection_weight"] + converted_weights = [weights] + elif path.endswith("/mm_soft_embedding_norm"): + converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") + elif path == _TRANSFORMER_FINAL_NORM: + converted_paths = ["language_model.model.norm.weight"] + converted_weights = [weights] + elif path.startswith(_TRANSFORMER_DECODER_BLOCK): + decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:] + next_path_seperator_idx = decoder_block_path.find("/") + layer_idx = decoder_block_path[:next_path_seperator_idx] + decoder_block_path = decoder_block_path[next_path_seperator_idx:] + + base_path = f"language_model.model.layers.{layer_idx}" + + if path.endswith("attn/attn_vec_einsum"): + converted_paths = [f"{base_path}.self_attn.o_proj.weight"] + converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)] + elif path.endswith("attn/_key_norm"): + converted_paths = [f"{base_path}.self_attn.k_norm.weight"] + converted_weights = [weights] + elif path.endswith("attn/kv_einsum"): + converted_paths = [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + k_proj_weights, v_proj_weights = weights + converted_weights = [ + k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + ] + elif path.endswith("attn/q_einsum"): + converted_paths = [f"{base_path}.self_attn.q_proj.weight"] + converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)] + elif path.endswith("attn/_query_norm"): + converted_paths = [f"{base_path}.self_attn.q_norm.weight"] + converted_weights = [weights] + elif path.endswith("mlp/gating_einsum"): + converted_paths = [ + f"{base_path}.mlp.gate_proj.weight", + f"{base_path}.mlp.up_proj.weight", + ] + gate_proj_weight, up_proj_weight = weights + converted_weights = [gate_proj_weight, up_proj_weight] + elif path.endswith("mlp/linear"): + converted_paths = [f"{base_path}.mlp.down_proj.weight"] + converted_weights = [weights.transpose()] + elif path.endswith("post_attention_norm"): + converted_paths = [f"{base_path}.post_attention_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("post_ffw_norm"): + converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_attention_norm"): + converted_paths = [f"{base_path}.input_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_ffw_norm"): + converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected path `{path}` in Decoder 
Block.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def transpose_reshape(x: torch.Tensor) -> torch.Tensor: + x = x.transpose(1, 2) + return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous() + + +@dataclasses.dataclass(frozen=True) +class ConversionResult: + state_tree: dict[str, torch.Tensor] + config: ShieldGemma2Config + + +def convert( + shieldgemma_checkpoint_path: str, + gemma_checkpoint_path: str, + config: ShieldGemma2Config, + target_dtype: torch.dtype, +) -> ConversionResult: + """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" + checkpointer = obc.PyTreeCheckpointer() + + sg2_ckpt = checkpointer.restore(shieldgemma_checkpoint_path) + g3_ckpt = checkpointer.restore(gemma_checkpoint_path) + + hf_tree: dict[str, torch.Tensor] = {} + + def update_tree(path: str, weights: np.ndarray) -> None: + torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype) + logging.info( + "%s converted shape=%s with dtype=%s", + path, + weights.shape, + torch_tensor.dtype, + ) + hf_tree[f"model.{path}"] = torch_tensor + + for paths, value in tree.flatten_with_path(g3_ckpt): + if paths[0].startswith("SigLiPFromPatches_"): + path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value) + update_tree(path, weights) + + for paths, value in tree.flatten_with_path(sg2_ckpt): + for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): + update_tree(path, weights) + + hf_tree["model.language_model.lm_head.weight"] = hf_tree["model.language_model.model.embed_tokens.weight"] + + return ConversionResult(state_tree=hf_tree, config=config) + + +def main(*args): + del args + + dtype = getattr(torch, PRECISION.value) + output_path = OUTPUT_PATH.value + + tokenizer = GemmaTokenizerFast( + TOKENIZER_PATH.value, + extra_special_tokens={ + "image_token": "", # Should be ID=262_144 + "boi_token": "", # Should be ID=255_999 + "eoi_token": "", # Should be ID=256_000 + }, + ) + + yes_token_index, no_token_index = torch.tensor(tokenizer(["Yes", "No"])["input_ids"])[:, 1].numpy() + + config = ShieldGemma2Config( + yes_token_index=int(yes_token_index), + no_token_index=int(no_token_index), + text_config=Gemma3TextConfig( + vocab_size=262_208, + hidden_size=2560, + intermediate_size=2560 * 8 // 2, + num_attention_heads=8, + head_dim=256, + num_hidden_layers=34, + num_key_value_heads=4, + sliding_window=1024, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256, + max_position_embeddings=8192, + ), + vision_config={ + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "num_channels": 3, + "image_size": 896, + "patch_size": 14, + "hidden_act": "gelu_pytorch_tanh", + "layer_norm_eps": 1e-6, + "attention_dropout": 0.0, + "vision_use_head": False, + }, + ) + + config.save_pretrained(output_path) + + image_processor = Gemma3ImageProcessor( + image_seq_length=256, + image_mean=(0.5,) * 3, + image_std=(0.5,) * 3, + size={"height": 896, "width": 896}, + resample=PILImageResampling.BILINEAR, + ) + processor = 
ShieldGemma2Processor( + image_processor=image_processor, + tokenizer=tokenizer, + policy_definitions=_SHIELDGEMMA2_POLICIES, + ) + tokenizer.chat_template = _CHAT_TEMPLATE + processor.chat_template = _CHAT_TEMPLATE + + processor.save_pretrained(output_path) + logging.info("Saved Shieldgemma2Processor to %s", output_path) + del processor + del tokenizer + + logging.info("Converting Shieldgemma2 @ %s", dtype) + result = convert(_SHIELDGEMMA_CHECKPOINT_PATH.value, _GEMMA_CHECKPOINT_PATH.value, config, dtype) + logging.info("Converted Shieldgemma2 state tree from Orbax to Hugging Face.") + + with accelerate.init_empty_weights(): + model = ShieldGemma2ForImageClassification(config=config) + + model.load_state_dict(result.state_tree, assign=True, strict=True) + model.config.torch_dtype = dtype + logging.info("Loaded Shieldgemma2 in Hugging Face Transformers.") + model.save_pretrained(output_path, safe_serialization=True) + logging.info("Saved Shieldgemma2 to SafeTensors in %s", output_path) + del model + del result + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/transformers/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py b/docs/transformers/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a65ea37c8dfa6e6e6cc50860fb937f78bf75e6 --- /dev/null +++ b/docs/transformers/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +import torch.utils.checkpoint + +from ...cache_utils import Cache +from ...modeling_outputs import ImageClassifierOutputWithNoAttention +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + logging, +) +from ..auto import AutoModelForImageTextToText +from .configuration_shieldgemma2 import ShieldGemma2Config + + +_CHECKPOINT_FOR_DOC = "google/shieldgemma-2-4b-it" +_CONFIG_FOR_DOC = "ShieldGemma2Config" + +logger = logging.get_logger(__name__) + +SHIELDGEMMA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@dataclass +class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWithNoAttention): + """ShieldGemma2 classifies imags as violative or not relative to a specific policy + Args: + """ + + probabilities: Optional[torch.Tensor] = None + + +class ShieldGemma2ForImageClassification(PreTrainedModel): + config_class = ShieldGemma2Config + + def __init__(self, config: ShieldGemma2Config): + super().__init__(config=config) + self.yes_token_index = getattr(config, "yes_token_index", 10_784) + self.no_token_index = getattr(config, "no_token_index", 3771) + self.model = AutoModelForImageTextToText.from_config(config=config) + + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.model.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.model.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.model.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.model.language_model.get_decoder() + + def tie_weights(self): + return self.model.language_model.tie_weights() + + @add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> ShieldGemma2ImageClassifierOutputWithNoAttention: + """Predicts the binary probability that the image violates the specified policy. + + Returns: + A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance containing the logits and probabilities + associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the + following properties. + + * `logits` (`torch.Tensor` of shape `(batch_size, 2)`): + The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is + the logits for the `No` token. + * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`): + The first position along dim=1 is the probability of predicting the `Yes` token and the second position + along dim=1 is the probability of predicting the `No` token. + + ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the + policy as described. If you are only interested in the violative condition, use + `violated = outputs.probabilities[:, 1]` to extract that slice from the output tensors. + + When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`, + and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN. 
+ """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + token_type_ids=token_type_ids, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + logits = outputs.logits + selected_logits = logits[:, -1, [self.yes_token_index, self.no_token_index]] + probabilities = torch.softmax(selected_logits, dim=-1) + return ShieldGemma2ImageClassifierOutputWithNoAttention( + logits=selected_logits, + probabilities=probabilities, + ) + + +__all__ = [ + "ShieldGemma2ForImageClassification", +] diff --git a/docs/transformers/src/transformers/models/shieldgemma2/processing_shieldgemma2.py b/docs/transformers/src/transformers/models/shieldgemma2/processing_shieldgemma2.py new file mode 100644 index 0000000000000000000000000000000000000000..097b83ef6ee6d5971535f7b2aad71efbc336c0a7 --- /dev/null +++ b/docs/transformers/src/transformers/models/shieldgemma2/processing_shieldgemma2.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Mapping, Sequence +from typing import Optional + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import Unpack +from ...utils import logging +from ..gemma3.processing_gemma3 import Gemma3Processor, Gemma3ProcessorKwargs + + +logger = logging.get_logger(__name__) + +DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = { + "dangerous": ( + "No Dangerous Content: The image shall not contain content that" + " facilitates or encourages activities that could cause real-world harm" + " (e.g., building firearms and explosive devices, promotion of" + " terrorism, instructions for suicide)." + ), + "sexual": ( + "No Sexually Explicit content: The image shall not contain content that" + " depicts explicit or graphic sexual acts (e.g., pornography, erotic" + " nudity, depictions of rape or sexual assault)." + ), + "violence": ( + "No Violence/Gore content: The image shall not contain content that" + " depicts shocking, sensational, or gratuitous violence (e.g.," + " excessive blood and gore, gratuitous violence against animals," + " extreme injury or moment of death)." 
+ ),
+}
+
+
+class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
+ policies: Optional[Sequence[str]]
+ custom_policies: Optional[Mapping[str, str]]
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ },
+ "images_kwargs": {
+ "do_pan_and_scan": False,
+ },
+ }
+
+
+class ShieldGemma2Processor(Gemma3Processor):
+ def __init__(
+ self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs
+ ):
+ """A processor for the ShieldGemma 2 model.
+
+ Args:
+ image_processor: The image processor to use, typically a `Gemma3ImageProcessorFast` instance.
+ tokenizer: The tokenizer to use, typically a `GemmaTokenizerFast` instance.
+ chat_template: The chat template to use with this processor. Typically, this is unset as the processor
+ configuration on Hugging Face Hub includes this value already.
+ image_seq_length: The number of soft tokens per image. Typically, this is unset as the processor
+ configuration on Hugging Face Hub includes this value already.
+ policy_definitions: A mapping from policy name to its description in text used as the default policies to
+ classify images against. The policy descriptions are included in the text of the prompts generated by
+ this processor. Typically, this is unset as the processor configuration on Hugging Face Hub includes
+ the base policies ShieldGemma was trained on.
+ """
+ super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
+ if policy_definitions is None:
+ self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES
+ else:
+ self.policy_definitions = policy_definitions
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text=None,
+ videos=None,
+ audio=None,
+ **kwargs: Unpack[ShieldGemma2ProcessorKwargs],
+ ) -> BatchFeature:
+ """Generates a batch of inputs from the provided images.
+
+ ShieldGemma was trained to classify image content for policy compliance using a specific prompt construction.
+ This processor generates a batch of such prompts from the provided images by:
+
+ 1. Creating a list of conversations, one for each `<image, policy>` pair;
+ 2. Converting these conversations to text using `self.apply_chat_template()`; and
+ 3. Encoding the conversations and images using the same techniques as `Gemma3Processor`.
+
+ Args:
+ images: A single image or a list of images to include in the batch.
+ text: Not supported.
+ videos: Not supported.
+ audio: Not supported.
+ kwargs: An optional dictionary of keyword arguments to configure the
+ processor. Possible values include:
+
+ * `custom_policies`: Additional policy definitions that augment the `self.policy_definitions` passed
+ into the constructor. Note that `custom_policies` that share a key with `self.policy_definitions`
+ will override the policy description.
+ * `policies`: (Optional) a list of keys in the joint `self.policy_definitions | custom_policies`
+ dictionary of specific interest for the provided images. If empty or None, prompts will be
+ generated for every key in the joint dictionary.
+
+ Returns:
+ A `BatchFeature` containing `input_ids`, `pixel_values`, etc. where each Tensor is of shape
+ `(len(images) * len(policies), )`, and the order within the batch will be
+ img1_policy1, ... img1_policyN, ... imgM_policyN.
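+
+ Example (a minimal sketch; `processor` is assumed to be a ShieldGemma 2 processor that ships the default
+ policy definitions, and the custom policy text is purely illustrative):
+
+ ```python
+ >>> inputs = processor(
+ ...     images=[image],
+ ...     custom_policies={"weapons": "No Weapons: The image shall not depict usable weapons."},
+ ...     policies=["dangerous", "weapons"],
+ ...     return_tensors="pt",
+ ... )
+ >>> inputs["input_ids"].shape[0]  # len(images) * len(policies)
+ 2
+ ```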
+ """ + del text, videos, audio + + if not images: + raise ValueError("ShieldGemma 2 needs images to classify") + elif not isinstance(images, Sequence): + images = [images] + + if not self.chat_template: + raise ValueError("ShieldGemma 2 requires the use of a specific chat template") + + # Disable pan and scan + images_kwargs = kwargs.setdefault("images_kwargs", {}) + if images_kwargs.get("do_pan_and_scan") is True: + logger.warning_once("ShieldGemma2 does not support pan and scan.") + images_kwargs["do_pan_and_scan"] = False + + # Enable padding on the batch during tokenization + text_kwargs = kwargs.setdefault("text_kwargs", {}) + if "padding" not in text_kwargs: + text_kwargs["padding"] = kwargs.pop("padding", True) + text_kwargs["padding_side"] = kwargs.pop("padding_side", "left") + + policy_definitions: Mapping[str, str] = { + **self.policy_definitions, + **kwargs.get("custom_policies", {}), + } + + if (policies := kwargs.get("policies")) is None: + policies = list(policy_definitions.keys()) + + # TODO(ryanmullins): Support images from PIL or URLs. + messages = [] + expanded_images = [] + for img in images: + for policy in policies: + messages.append( + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": policy_definitions[policy]}, + ], + } + ] + ) + expanded_images.append([img]) + + text = self.apply_chat_template(messages, tokenize=False) + return super().__call__(images=expanded_images, text=text, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["ShieldGemma2Processor"] diff --git a/docs/transformers/src/transformers/models/siglip/__init__.py b/docs/transformers/src/transformers/models/siglip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5861f45f45b06bc01bbaa48f96fcc4f638fc590 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_siglip import *
+ from .image_processing_siglip import *
+ from .image_processing_siglip_fast import *
+ from .modeling_siglip import *
+ from .processing_siglip import *
+ from .tokenization_siglip import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/siglip/configuration_siglip.py b/docs/transformers/src/transformers/models/siglip/configuration_siglip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a140cecc2c1082c0fa973059df4690a63b52b6
--- /dev/null
+++ b/docs/transformers/src/transformers/models/siglip/configuration_siglip.py
@@ -0,0 +1,269 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Siglip model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SiglipTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`SiglipModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 64):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler.
If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + projection_size (`int`, *optional*, defaults to `hidden_size`): + The size of the projection head. + + Example: + + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + projection_size=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.projection_size = projection_size if projection_size is not None else hidden_size + + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+
+ Example:
+
+ ```python
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
+
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
+ >>> configuration = SiglipVisionConfig()
+
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
+ >>> model = SiglipVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "siglip_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=224,
+ patch_size=16,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+
+
+class SiglipConfig(PretrainedConfig):
+ r"""
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import SiglipConfig, SiglipModel
+
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
+ >>> configuration = SiglipConfig()
+
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
+ >>> model = SiglipModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
+
+ >>> # Initializing a SiglipText and SiglipVision configuration
+ >>> config_text = SiglipTextConfig()
+ >>> config_vision = SiglipVisionConfig()
+
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
+ ```"""
+
+ model_type = "siglip"
+ sub_configs = {"text_config": SiglipTextConfig, "vision_config": SiglipVisionConfig}
+
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
+
+ self.text_config = SiglipTextConfig(**text_config)
+ self.vision_config = SiglipVisionConfig(**vision_config)
+
+ self.initializer_factor = 1.0
+
+ @classmethod
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
+ r"""
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
+ model configuration.
+
+ Returns:
+ [`SiglipConfig`]: An instance of a configuration object
+ """
+
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+
+__all__ = ["SiglipConfig", "SiglipTextConfig", "SiglipVisionConfig"]
diff --git a/docs/transformers/src/transformers/models/siglip/convert_siglip_to_hf.py b/docs/transformers/src/transformers/models/siglip/convert_siglip_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b61bd7ffb70edaf0674bb3312e047d0b50ff2524
--- /dev/null
+++ b/docs/transformers/src/transformers/models/siglip/convert_siglip_to_hf.py
@@ -0,0 +1,533 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert SigLIP checkpoints from the original repository.
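+
+Example usage (illustrative; assumes the checkpoint file for the chosen model is available locally or on the Hub):
+
+    python convert_siglip_to_hf.py --model_name siglip-base-patch16-224 --pytorch_dump_folder_path ./siglip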
+ +URL: https://github.com/google-research/big_vision/tree/main +""" + +import argparse +import collections +import os +from typing import Tuple + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from numpy import load +from PIL import Image + +from transformers import ( + GemmaTokenizerFast, + SiglipConfig, + SiglipImageProcessor, + SiglipModel, + SiglipProcessor, + SiglipTokenizer, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +MODEL_CONFIGS = { + "base": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 12, + "num_attention_heads": 12, + }, + "large": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_hidden_layers": 24, + "num_attention_heads": 16, + }, + "giant-opt": { + "hidden_size": 1536, + "intermediate_size": 6144, + "num_hidden_layers": 40, + "num_attention_heads": 16, + }, + "so400m": { + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + }, +} + +model_name_to_checkpoint = { + # base checkpoints + "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz", + "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz", + "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz", + "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz", + # large checkpoints + "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz", + "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz", + # multilingual checkpoint + "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz", + # so400m checkpoints + "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz", + # ----------------- v2 ----------------- + # base checkpoints + "siglip2-base-patch32-256": "gv-hf/siglip2/siglip2_b32_256.npz", + "siglip2-base-patch16-224": "gv-hf/siglip2/siglip2_b16_224.npz", + "siglip2-base-patch16-256": "gv-hf/siglip2/siglip2_b16_256.npz", + "siglip2-base-patch16-384": "gv-hf/siglip2/siglip2_b16_384.npz", + "siglip2-base-patch16-512": "gv-hf/siglip2/siglip2_b16_512.npz", + # large checkpoints + "siglip2-large-patch16-256": "gv-hf/siglip2/siglip2_l16_256.npz", + "siglip2-large-patch16-384": "gv-hf/siglip2/siglip2_l16_384.npz", + "siglip2-large-patch16-512": "gv-hf/siglip2/siglip2_l16_512.npz", + # giant opt checkpoints + "siglip2-giant-opt-patch16-256": "gv-hf/siglip2/siglip2_g-opt16_256.npz", + "siglip2-giant-opt-patch16-384": "gv-hf/siglip2/siglip2_g-opt16_384.npz", + # so400m checkpoints + "siglip2-so400m-patch14-224": "gv-hf/siglip2/siglip2_so400m14_224.npz", + "siglip2-so400m-patch14-384": "gv-hf/siglip2/siglip2_so400m14_384.npz", + "siglip2-so400m-patch16-256": "gv-hf/siglip2/siglip2_so400m16_256.npz", + "siglip2-so400m-patch16-384": "gv-hf/siglip2/siglip2_so400m16_384.npz", + "siglip2-so400m-patch16-512": "gv-hf/siglip2/siglip2_so400m16_512.npz", +} + +# ------------------------------------------------------------------------------------------------------ +# CONFIG +# ------------------------------------------------------------------------------------------------------ + + +def get_image_size_from_model_name(model_name: str) -> int: + if "-i18n" not in model_name: + size = model_name.split("-")[-1] 
+ else: + size = model_name.split("-")[-2] + return int(size) + + +def get_patch_size_from_model_name(model_name: str) -> int: + patch_str = [x for x in model_name.split("-") if "patch" in x][0] + return int(patch_str[-2:]) + + +def get_vocab_size_from_model_name(model_name: str) -> int: + if "siglip2" in model_name: + vocab_size = 256000 + elif "-i18n" in model_name: + vocab_size = 250000 + else: + vocab_size = 32000 + return vocab_size + + +def get_vocab_file_from_model_name(model_name: str) -> str: + # get vocab file + if "i18n" in model_name: + vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model" + else: + vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model" + return vocab_file + + +def get_text_and_vision_vit_variants(model_name: str) -> Tuple[str, str]: + variant = model_name.split("-")[1] if "giant-opt" not in model_name else "giant-opt" + return { + "base": ("base", "base"), + "large": ("large", "large"), + "so400m": ("so400m", "so400m"), + # g-opt siglip2 is not symmetric + "giant-opt": ("so400m", "giant-opt"), + }[variant] + + +def get_siglip_config(model_name): + text_variant, vision_variant = get_text_and_vision_vit_variants(model_name) + text_config = MODEL_CONFIGS[text_variant].copy() + vision_config = MODEL_CONFIGS[vision_variant].copy() + + text_config["vocab_size"] = get_vocab_size_from_model_name(model_name) + vision_config["image_size"] = get_image_size_from_model_name(model_name) + vision_config["patch_size"] = get_patch_size_from_model_name(model_name) + + if text_config["hidden_size"] != vision_config["hidden_size"]: + text_config["projection_size"] = vision_config["hidden_size"] + + return SiglipConfig(text_config=text_config, vision_config=vision_config) + + +# ------------------------------------------------------------------------------------------------------ +# PROCESSING +# ------------------------------------------------------------------------------------------------------ + + +def get_tokenizer(model_name: str) -> GemmaTokenizerFast: + if "siglip2" in model_name: + tokenizer = GemmaTokenizerFast.from_pretrained( + "google/gemma-2-9b-it", + add_bos_token=False, + add_eos_token=True, + padding_side="right", + do_lower_case=True, + # important: make tokenizer NOT return attention_mask since original one doesn't require it + model_input_names=["input_ids"], + ) + else: + # for siglip v1 + vocab_file = get_vocab_file_from_model_name(model_name) + # important: make tokenizer not return attention_mask since original one doesn't require it + tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"]) + return tokenizer + + +def get_image_processor(model_name: str) -> SiglipImageProcessor: + image_size = get_image_size_from_model_name(model_name) + size = {"height": image_size, "width": image_size} + if "siglip2" in model_name: + image_processor = SiglipImageProcessor(size=size, resample=2) # bilinear resampling + else: + image_processor = SiglipImageProcessor(size=size) + return image_processor + + +# ------------------------------------------------------------------------------------------------------ +# CONVERT FUNCTIONS +# ------------------------------------------------------------------------------------------------------ + + +def split_encoderblock_layers(state_dict: dict) -> dict: + """ + Split the encoderblock weight into layers. In some cases they are concatenated in + the original checkpoints. 
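+
+ For example, a stacked key such as `params/img/Transformer/encoderblock/LayerNorm_0/scale`, whose leading
+ axis runs over layers, is split into `.../encoderblock_0/LayerNorm_0/scale`,
+ `.../encoderblock_1/LayerNorm_0/scale`, and so on, one entry per layer.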
+ """ + # Make shallow copy + state_dict = state_dict.copy() + # Split encoderblock weight into layers + keys = list(state_dict.keys()) + for key in keys: + if "/encoderblock/" in key: + weight = state_dict.pop(key) + for i, weight_i in enumerate(weight): + new_name = key.replace("encoderblock", f"encoderblock_{i}") + state_dict[new_name] = weight_i + return state_dict + + +def create_rename_keys(config): + rename_keys = [] + # fmt: off + + # vision encoder + + rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight")) + rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias")) + rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight")) + + for i in range(config.vision_config.num_hidden_layers): + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight")) + rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias")) + + rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight")) + rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias")) + + 
rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe")) + rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight")) + rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias")) + rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight")) + rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias")) + rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight")) + rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias")) + rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight")) + rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias")) + + # text encoder + + rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight")) + rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight")) + + for i in range(config.text_config.num_hidden_layers): + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias")) + rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight")) + 
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias")) + + rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight")) + rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias")) + rename_keys.append(("params/txt/head/kernel", "text_model.head.weight")) + rename_keys.append(("params/txt/head/bias", "text_model.head.bias")) + + # learned temperature and bias + rename_keys.append(("params/t", "logit_scale")) + rename_keys.append(("params/b", "logit_bias")) + + # fmt: on + return rename_keys + + +def rename_key(dct, old, new, config): + val = dct.pop(old) + + if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new: + val = val.reshape(-1, config.vision_config.hidden_size) + if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new: + val = val.reshape(-1, config.text_config.hidden_size) + + if "patch_embedding.weight" in new: + val = val.transpose(3, 2, 0, 1) + elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new: + val = val.T + + if "position_embedding" in new and "vision" in new: + val = val.reshape(-1, config.vision_config.hidden_size) + if "position_embedding" in new and "text" in new: + val = val.reshape(-1, config.text_config.hidden_size) + + if new.endswith("bias"): + val = val.reshape(-1) + + dct[new] = torch.from_numpy(val) + + +def read_in_q_k_v_head(state_dict, config): + # read in individual input projection layers + key_proj_weight = ( + state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel") + .reshape(-1, config.vision_config.hidden_size) + .T + ) + key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1) + value_proj_weight = ( + state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel") + .reshape(-1, config.vision_config.hidden_size) + .T + ) + value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1) + query_proj_weight = ( + state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel") + .reshape(-1, config.vision_config.hidden_size) + .T + ) + query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1) + + # next, add them to the state dict as a single matrix + vector + state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy( + np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0) + ) + state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy( + np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0) + ) + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + return image + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@torch.no_grad() +def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, 
push_to_hub=False): + """ + Copy/paste/tweak model's weights to our SigLIP structure. + """ + + # Define default SigLIP configuration + config = get_siglip_config(model_name) + + # Get checkpoint + checkpoint = model_name_to_checkpoint[model_name] + if not os.path.exists(checkpoint): + org, repo_id, *filepath = checkpoint.split("/") + checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath)) + + # Load original state dict + data = load(checkpoint) + state_dict = flatten_nested_dict(data) + state_dict = split_encoderblock_layers(state_dict) + + # Remove and rename some keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest, config) + + # qkv matrices of attention pooling head need special treatment + read_in_q_k_v_head(state_dict, config) + + # Load HuggingFace model + model = SiglipModel(config).eval() + model.load_state_dict(state_dict) + + # Create processor + image_processor = get_image_processor(model_name) + tokenizer = get_tokenizer(model_name) + processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # Verify forward pass on dummy images and texts + url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg" + image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB") + url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg" + image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB") + texts = ["an apple", "a picture of an apple"] + + inputs = processor(images=[image_1, image_2], text=texts, padding="max_length", max_length=64, return_tensors="pt") + with torch.no_grad(): + outputs = model(**inputs) + + if verify_logits: + image_size = config.vision_config.image_size + + # verify input_ids against original ones + if image_size == 224: + filename = "siglip_pixel_values.pt" + elif image_size == 256: + filename = "siglip_pixel_values_256.pt" + elif image_size == 384: + filename = "siglip_pixel_values_384.pt" + elif image_size == 512: + filename = "siglip_pixel_values_512.pt" + else: + raise ValueError("Image size not supported") + + filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset") + original_pixel_values = torch.load(filepath, weights_only=True) + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset") + original_input_ids = torch.load(filepath, weights_only=True) + + if "i18n" not in model_name: + assert inputs.input_ids.tolist() == original_input_ids.tolist() + + print("Mean of original pixel values:", original_pixel_values.mean()) + print("Mean of new pixel values:", inputs.pixel_values.mean()) + + # note: we're testing with original pixel values here since we don't have exact pixel values + with torch.no_grad(): + outputs = model(input_ids=original_input_ids, pixel_values=original_pixel_values) + print(outputs.logits_per_image[:3, :3]) + + probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities + print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'") + + if model_name == "siglip-base-patch16-224": + expected_slice = torch.tensor( + [[-2.9621, -2.1672], [-0.2713, 0.2910]], + ) + elif model_name == "siglip-base-patch16-256": + expected_slice = torch.tensor( + [[-3.1146, -1.9894], [-0.7312, 0.6387]], + ) + elif model_name == "siglip-base-patch16-384": + expected_slice = torch.tensor( + [[-2.8098, -2.1891], 
[-0.4242, 0.4102]], + ) + elif model_name == "siglip-base-patch16-512": + expected_slice = torch.tensor( + [[-2.7899, -2.2668], [-0.4295, -0.0735]], + ) + elif model_name == "siglip-large-patch16-256": + expected_slice = torch.tensor( + [[-1.5827, -0.5801], [-0.9153, 0.1363]], + ) + elif model_name == "siglip-large-patch16-384": + expected_slice = torch.tensor( + [[-2.1523, -0.2899], [-0.2959, 0.7884]], + ) + elif model_name == "siglip-so400m-patch14-384": + expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]]) + elif model_name == "siglip-base-patch16-256-i18n": + expected_slice = torch.tensor( + [[-0.9064, 0.1073], [-0.0299, 0.5304]], + ) + + assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + pytorch_dump_folder_path = os.path.join(pytorch_dump_folder_path, model_name) + os.makedirs(pytorch_dump_folder_path, exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"s0225/{model_name}", private=True) + processor.push_to_hub(f"s0225/{model_name}", private=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="siglip-base-patch16-224", + type=str, + choices=model_name_to_checkpoint.keys(), + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--verify_logits", + action="store_true", + help="Whether to verify logits against the original implementation.", + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/siglip/image_processing_siglip.py b/docs/transformers/src/transformers/models/siglip/image_processing_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec6c36d39ca83dda2516f13548a0853e8f8afb1 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip/image_processing_siglip.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for SigLIP.""" + +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class SiglipImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + height, width = size["height"], size["width"] + images = [ + resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["SiglipImageProcessor"] diff --git a/docs/transformers/src/transformers/models/siglip/image_processing_siglip_fast.py b/docs/transformers/src/transformers/models/siglip/image_processing_siglip_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..b28f89dbbf36c6ec5e51aa75d028a231b440eeff --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip/image_processing_siglip_fast.py @@ -0,0 +1,41 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SigLIP.""" + +from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + PILImageResampling, +) +from ...utils import add_start_docstrings + + +@add_start_docstrings( + "Constructs a fast SigLIP image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class SiglipImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + + +__all__ = ["SiglipImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/siglip/modeling_siglip.py b/docs/transformers/src/transformers/models/siglip/modeling_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..f77248970c4826ef51b5f3625105fa51cbb04167 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip/modeling_siglip.py @@ -0,0 +1,1392 @@ +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Siglip model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SiglipConfig" +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
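+
+    Example (illustrative; `trunc_normal_tf_` is internal to this module, so it must be
+    imported from here directly):
+
+    ```python
+    >>> import torch
+    >>> w = torch.empty(768, 768)
+    >>> trunc_normal_tf_(w, std=0.02)      # defaults: mean=0.0, a=-2.0, b=2.0
+    >>> bool(w.abs().max() <= 2 * 0.02)    # cutoffs scale with std in this variant
+    True
+    ```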
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. 
+
+    Args:
+        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+            The text embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    text_embeds: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
+class SiglipOutput(ModelOutput):
+    """
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Contrastive loss for image-text similarity.
+        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+            similarity scores.
+        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+            similarity scores.
+        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
+        text_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`SiglipTextModel`].
+        vision_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`SiglipVisionModel`].
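+
+    Note: `logits_per_image` is simply the transpose of `logits_per_text`, and because SigLIP is trained with a
+    pairwise sigmoid loss, per-pair probabilities are obtained with `torch.sigmoid(logits_per_image)` rather than
+    a softmax over the batch.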
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class 
SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = SiglipAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. 
+ output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipEncoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, SiglipForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @can_return_tuple + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
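+
+        Each layer applies a pre-LayerNorm residual update, schematically:
+
+            hidden = hidden + SelfAttention(LayerNorm1(hidden))
+            hidden = hidden + MLP(LayerNorm2(hidden))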
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, config.projection_size) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. 
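+        # Because SigLIP pads with the EOS token up to a fixed `max_length` ("sticky" EOS)
+        # and the text encoder is non-causal, position -1 always holds an EOS state; a
+        # simple `[:, -1, :]` slice therefore pools it without a per-example length gather,
+        # and `self.head` projects the result to `config.projection_size`.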
+ pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs: BaseModelOutput = 
self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise TypeError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" 
{type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = text_outputs.pooler_output + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. 
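+
+        Note: the returned features are not L2-normalized; [`SiglipModel.forward`] normalizes both image and text
+        embeddings before computing the similarity logits, so apply
+        `features = features / features.norm(p=2, dim=-1, keepdim=True)` when comparing embeddings directly.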
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs.pooler_output + + return pooled_output + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> SiglipOutput: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
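+        # Illustrative note on the pairwise sigmoid loss computed in the `return_loss`
+        # branch below: with a sign matrix z that is +1 on the diagonal (matching
+        # image-text pairs) and -1 elsewhere,
+        #
+        #     loss = mean_i( -sum_j log_sigmoid( z_ij * (t * dot(x_i, y_j) + b) ) )
+        #
+        # where t = logit_scale.exp() and b = logit_bias; `m1_diag1 = -ones + 2 * eye`
+        # constructs z.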
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings( + """ + SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. + """, + SIGLIP_START_DOCSTRING, +) +class SiglipForImageClassification(SiglipPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: SiglipConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = SiglipVisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SiglipForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `SiglipModel` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. + >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + sequence_output = torch.mean(sequence_output, dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", + "SiglipForImageClassification", +] diff --git a/docs/transformers/src/transformers/models/siglip/processing_siglip.py b/docs/transformers/src/transformers/models/siglip/processing_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d66fb9d5f9907c8850e97d8a1da11e30a7be8d1 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip/processing_siglip.py @@ 
-0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP. +""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class SiglipProcessor(ProcessorMixin): + r""" + Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. + + [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the + [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`]): + The image processor is a required input. + tokenizer ([`SiglipTokenizer`]): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast") + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` argument to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. 
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + if images is not None: + image_features = self.image_processor(images, return_tensors=return_tensors) + + if text is not None and images is not None: + encoding.update(image_features) + return encoding + elif text is not None: + return encoding + else: + return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
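+
+        Example (illustrative; the exact decoded output may vary with the checkpoint's tokenizer):
+
+        ```python
+        >>> from transformers import SiglipProcessor
+
+        >>> processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
+        >>> input_ids = processor(text=["a photo of a cat"], padding="max_length").input_ids
+        >>> processor.batch_decode(input_ids, skip_special_tokens=True)
+        ['a photo of a cat']
+        ```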
+ """ + return self.tokenizer.batch_decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["SiglipProcessor"] diff --git a/docs/transformers/src/transformers/models/siglip/tokenization_siglip.py b/docs/transformers/src/transformers/models/siglip/tokenization_siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e62dcd049df5887460d00c1b933115086f324b --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip/tokenization_siglip.py @@ -0,0 +1,380 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for SigLIP model.""" + +import os +import re +import string +import warnings +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import AddedToken + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput +from ...utils import logging, requires_backends +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +SPIECE_UNDERLINE = "▁" + + +@requires(backends=("sentencepiece",)) +class SiglipTokenizer(PreTrainedTokenizer): + """ + Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. 
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        model_max_length (`int`, *optional*, defaults to 64):
+            The maximum length (in number of tokens) for model inputs.
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="</s>",
+        additional_special_tokens=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        model_max_length=64,
+        do_lower_case=True,
+        **kwargs,
+    ) -> None:
+        requires_backends(self, "protobuf")
+
+        pad_token = (
+            AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+        unk_token = (
+            AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        eos_token = (
+            AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.do_lower_case = do_lower_case
+        self.vocab_file = vocab_file
+
+        self.sp_model = self.get_spm_processor()
+
+        super().__init__(
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            model_max_length=model_max_length,
+            do_lower_case=do_lower_case,
+            **kwargs,
+        )
+
+    def get_spm_processor(self):
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf()
+            model = model_pb2.ModelProto.FromString(sp_model)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
+
+    @property
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
+    def vocab_size(self):
+        return self.sp_model.get_piece_size()
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        # normal case: some special tokens
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
+    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
+        """Do not add eos again if user already added it."""
+        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+            warnings.warn(
+                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+                " eos tokens being added."
+            )
+            return token_ids
+        else:
+            return token_ids + [self.eos_token_id]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+        adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+        if token_ids_1 is None:
+            return token_ids_0
+        else:
+            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+            return token_ids_0 + token_ids_1
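To make the EOS handling concrete, a small illustration (the token ids and the eos id are made up for the example; `tok` stands for an instantiated `SiglipTokenizer`):

```python
# Assume tok.eos_token_id == 1 for this sketch.
tok.build_inputs_with_special_tokens([31, 7])        # -> [31, 7, 1]
tok.build_inputs_with_special_tokens([31, 7], [42])  # -> [31, 7, 1, 42, 1]
tok.get_special_tokens_mask([31, 7])                 # -> [0, 0, 1]  (only the appended eos is special)
```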
+ """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def remove_punctuation(self, text: str) -> str: + return text.translate(str.maketrans("", "", string.punctuation)) + + # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + def canonicalize_text(self, text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (puncuation removed). + + Args: + text (`str`): + String to be canonicalized. + keep_punctuation_exact_string (`str`, *optional*): + If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}' + (but will still remove '{' and '}' that appear separately). + """ + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string) + ) + else: + text = self.remove_punctuation(text) + text = re.sub(r"\s+", " ", text) + text = text.strip() + + return text + + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. + """ + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. + + For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`. + + Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + text = self.canonicalize_text(text, keep_punctuation_exact_string=None) + tokens = self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["SiglipTokenizer"]
diff --git a/docs/transformers/src/transformers/models/siglip2/__init__.py b/docs/transformers/src/transformers/models/siglip2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5b732b951367cd130d7edd3299eda74f9dc267
--- /dev/null
+++ b/docs/transformers/src/transformers/models/siglip2/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_siglip2 import * + from .image_processing_siglip2 import * + from .image_processing_siglip2_fast import * + from .modeling_siglip2 import * + from .processing_siglip2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/siglip2/configuration_siglip2.py b/docs/transformers/src/transformers/models/siglip2/configuration_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb379c670adeea5f1a511f8561006c89a21ae76 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/configuration_siglip2.py @@ -0,0 +1,277 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_siglip2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Siglip2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2TextModel`]. It is used to instantiate a + Siglip2 text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip2 + [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Siglip2Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + projection_size (`int`, *optional*, defaults to `hidden_size`): + The size of the projection head. + + Example: + + ```python + >>> from transformers import Siglip2TextConfig, Siglip2TextModel + + >>> # Initializing a Siglip2TextConfig with google/siglip2-base-patch16-224 style configuration + >>> configuration = Siglip2TextConfig() + + >>> # Initializing a Siglip2TextModel (with random weights) from the google/siglip2-base-patch16-224 style configuration + >>> model = Siglip2TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip2_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip2 + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + projection_size=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.projection_size = projection_size if projection_size is not None else hidden_size + + +class Siglip2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a + Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2 + [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). + The image is resized to fill maximum of this number of patches, and to preserve + the aspect ratio. In case the resulted number of patches is lower, the image is + padded in "patch" dimension. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel + + >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration + >>> configuration = Siglip2VisionConfig() + + >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration + >>> model = Siglip2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip2_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_patches = num_patches + + +class Siglip2Config(PretrainedConfig): + r""" + [`Siglip2Config`] is the configuration class to store the configuration of a [`Siglip2Model`]. It is used to + instantiate a Siglip2 model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip2 + [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Siglip2TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Siglip2VisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import Siglip2Config, Siglip2Model + + >>> # Initializing a Siglip2Config with google/siglip2-base-patch16-224 style configuration + >>> configuration = Siglip2Config() + + >>> # Initializing a Siglip2Model (with random weights) from the google/siglip2-base-patch16-224 style configuration + >>> model = Siglip2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Siglip2Config from a Siglip2TextConfig and a Siglip2VisionConfig + >>> from transformers import Siglip2TextConfig, Siglip2VisionConfig + + >>> # Initializing a Siglip2Text and Siglip2Vision configuration + >>> config_text = Siglip2TextConfig() + >>> config_vision = Siglip2VisionConfig() + + >>> config = Siglip2Config.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip2" + sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig} + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `Siglip2VisionConfig` with default values.") + + self.text_config = Siglip2TextConfig(**text_config) + self.vision_config = Siglip2VisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: Siglip2TextConfig, vision_config: Siglip2VisionConfig, **kwargs): + r""" + Instantiate a [`Siglip2Config`] (or a derived class) from siglip2 text model configuration and siglip2 vision + model configuration. + + Returns: + [`Siglip2Config`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +__all__ = ["Siglip2Config", "Siglip2TextConfig", "Siglip2VisionConfig"] diff --git a/docs/transformers/src/transformers/models/siglip2/convert_siglip2_to_hf.py b/docs/transformers/src/transformers/models/siglip2/convert_siglip2_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..819596498996c1af9623f07e4cf2dea1f663ccef --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/convert_siglip2_to_hf.py @@ -0,0 +1,438 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Siglip2 checkpoints from the original repository. 
+ +URL: https://github.com/google-research/big_vision/tree/main +""" + +import argparse +import collections +import os +import re + +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image, ImageDraw + +from transformers import GemmaTokenizerFast, Siglip2Config, Siglip2ImageProcessorFast, Siglip2Model, Siglip2Processor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +COMMON_CONFIG_PARAMS = { + "base": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 12, + "num_attention_heads": 12, + }, + "large": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_hidden_layers": 24, + "num_attention_heads": 16, + }, + "so400m": { + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + }, +} + +MODEL_NAME_TO_CHECKPOINT_PATH = { + # base checkpoints + "siglip2-base-patch16-naflex": "gv-hf/siglip2/siglip2_b16_naflex.npz", + "siglip2-so400m-patch16-naflex": "gv-hf/siglip2/siglip2_so400m16_naflex.npz", +} + +# fmt: off +EXPECTED_OUTPUTS = { + "siglip2-base-patch16-naflex": torch.tensor([ + [ 1.0195, -0.0280, -1.4468], + [ -4.5395, -6.2269, -1.5667], + [ 4.1757, 5.0358, 3.5159], + [ 9.4264, 10.1879, 6.3353], + [ 2.4409, 3.1058, 4.5491], + [-12.3230, -13.7355, -13.4632], + [ 1.1520, 1.1687, -1.9647], + ]), + "siglip2-so400m-patch16-naflex": torch.tensor([ + [ 0.9422, 0.5540, -2.4405], + [ -7.3522, -9.4931, -6.3499], + [ 5.7852, 6.7288, 7.7893], + [ 9.9881, 10.8136, 9.2121], + [ 5.3660, 5.7746, 8.4130], + [-12.7218, -14.2631, -13.6442], + [ 0.6384, 0.4278, -0.9022], + ]), +} +# fmt: on + +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Vision embeddings + r"params/img/embedding/kernel": r"vision_model.embeddings.patch_embedding.weight", + r"params/img/embedding/bias": r"vision_model.embeddings.patch_embedding.bias", + r"params/img/pos_embedding": r"vision_model.embeddings.position_embedding.weight", + # Vision encoder + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_0/scale": r"vision_model.encoder.layers.\1.layer_norm1.weight", + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_0/bias": r"vision_model.encoder.layers.\1.layer_norm1.bias", + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_1/scale": r"vision_model.encoder.layers.\1.layer_norm2.weight", + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_1/bias": r"vision_model.encoder.layers.\1.layer_norm2.bias", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_0/kernel": r"vision_model.encoder.layers.\1.mlp.fc1.weight", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_0/bias": r"vision_model.encoder.layers.\1.mlp.fc1.bias", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_1/kernel": r"vision_model.encoder.layers.\1.mlp.fc2.weight", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_1/bias": r"vision_model.encoder.layers.\1.mlp.fc2.bias", + r"params/img/Transformer/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/kernel": r"vision_model.encoder.layers.\1.self_attn.\2_proj.weight", + r"params/img/Transformer/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/bias": r"vision_model.encoder.layers.\1.self_attn.\2_proj.bias", + # Vision norm + r"params/img/Transformer/encoder_norm/scale": r"vision_model.post_layernorm.weight", + r"params/img/Transformer/encoder_norm/bias": r"vision_model.post_layernorm.bias", + # Vision head + 
r"params/img/MAPHead_0/probe": r"vision_model.head.probe", + r"params/img/MAPHead_0/LayerNorm_0/scale": r"vision_model.head.layernorm.weight", + r"params/img/MAPHead_0/LayerNorm_0/bias": r"vision_model.head.layernorm.bias", + r"params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel": r"vision_model.head.mlp.fc1.weight", + r"params/img/MAPHead_0/MlpBlock_0/Dense_0/bias": r"vision_model.head.mlp.fc1.bias", + r"params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel": r"vision_model.head.mlp.fc2.weight", + r"params/img/MAPHead_0/MlpBlock_0/Dense_1/bias": r"vision_model.head.mlp.fc2.bias", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel": r"vision_model.head.attention.out_proj.weight", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias": r"vision_model.head.attention.out_proj.bias", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel": r"vision_model.head.attention.in_proj_weight", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias": r"vision_model.head.attention.in_proj_bias", + # Text embeddings + r"params/txt/Embed_0/embedding": r"text_model.embeddings.token_embedding.weight", + r"params/txt/pos_embedding": r"text_model.embeddings.position_embedding.weight", + # Text encoder + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_0/scale": r"text_model.encoder.layers.\1.layer_norm1.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_0/bias": r"text_model.encoder.layers.\1.layer_norm1.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_1/scale": r"text_model.encoder.layers.\1.layer_norm2.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_1/bias": r"text_model.encoder.layers.\1.layer_norm2.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_0/kernel": r"text_model.encoder.layers.\1.mlp.fc1.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_0/bias": r"text_model.encoder.layers.\1.mlp.fc1.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_1/kernel": r"text_model.encoder.layers.\1.mlp.fc2.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_1/bias": r"text_model.encoder.layers.\1.mlp.fc2.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/kernel": r"text_model.encoder.layers.\1.self_attn.\2_proj.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/bias": r"text_model.encoder.layers.\1.self_attn.\2_proj.bias", + # Text encoder norm and head + r"params/txt/Encoder_0/encoder_norm/scale": r"text_model.final_layer_norm.weight", + r"params/txt/Encoder_0/encoder_norm/bias": r"text_model.final_layer_norm.bias", + r"params/txt/head/kernel": r"text_model.head.weight", + r"params/txt/head/bias": r"text_model.head.bias", + # learned temperature and bias + r"params/t": r"logit_scale", + r"params/b": r"logit_bias", +} +# fmt: on + + +# -------------------------------------------------------------------------------------------- +# Model objects: configuration, tokenizer, image processor +# -------------------------------------------------------------------------------------------- + + +def get_siglip2_config(model_name: str) -> Siglip2Config: + """ + Create a configuration for the Siglip2 model based on the model name. 
+ """ + + _, variant, patch, _ = model_name.split("-") + patch_size = int(patch[-2:]) + num_patches = 256 + + common_options = COMMON_CONFIG_PARAMS[variant] + vision_config = { + "patch_size": patch_size, + "num_patches": num_patches, + **common_options, + } + text_config = { + "vocab_size": 256_000, + **common_options, + } + config = Siglip2Config( + vision_config=vision_config, + text_config=text_config, + ) + return config + + +def get_siglip2_tokenizer() -> GemmaTokenizerFast: + # Load pretrained tokenizer + gemma_checkpoint = "google/gemma-7b" + tokenizer = GemmaTokenizerFast.from_pretrained( + gemma_checkpoint, + add_bos_token=False, + add_eos_token=True, + padding_side="right", + do_lower_case=True, + # important: make tokenizer NOT return attention_mask since original one doesn't require it + model_input_names=["input_ids"], + ) + return tokenizer + + +def get_siglip2_image_processor(patch_size: int, max_num_patches: int) -> Siglip2ImageProcessorFast: + image_processor = Siglip2ImageProcessorFast( + patch_size=patch_size, + max_num_patches=max_num_patches, + do_resize=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_rescale=True, + rescale_factor=1 / 255, + resample=Image.Resampling.BILINEAR, + ) + return image_processor + + +# -------------------------------------------------------------------------------------------- +# Helper functions for state dict conversion +# -------------------------------------------------------------------------------------------- + + +def flatten_nested_dict(params: dict, parent_key: str = "", sep: str = "/") -> dict: + """ + Flatten a nested original checkpoint dictionary into a flat dictionary. + """ + items = [] + for k, v in params.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def split_encoderblock_layers(state_dict: dict) -> dict: + """ + Split the encoderblock weight into layers. In some cases they are concatenated in + the original checkpoints. + """ + # Make shallow copy + state_dict = state_dict.copy() + # Split encoderblock weight into layers + keys = list(state_dict.keys()) + for key in keys: + if "/encoderblock/" in key: + weight = state_dict.pop(key) + for i, weight_i in enumerate(weight): + new_name = key.replace("encoderblock", f"encoderblock_{i}") + state_dict[new_name] = weight_i + return state_dict + + +def merge_qkv_for_head(state_dict: dict, config: Siglip2Config) -> dict: + """ + Merge the q/k/v weights and biases for the attention head. 
+ """ + # Make shallow copy + state_dict = state_dict.copy() + # Read and process q/k/v weights and biases + qkv_weights, qkv_biases = [], [] + for name in ["query", "key", "value"]: + prefix = f"params/img/MAPHead_0/MultiHeadDotProductAttention_0/{name}" + weight = state_dict.pop(f"{prefix}/kernel").reshape(-1, config.vision_config.hidden_size) + bias = state_dict.pop(f"{prefix}/bias").reshape(-1) + qkv_weights.append(weight) + qkv_biases.append(bias) + + # Combine into single tensors + state_dict["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel"] = np.concatenate(qkv_weights, axis=1) + state_dict["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias"] = np.concatenate(qkv_biases, axis=0) + return state_dict + + +def convert_old_keys_to_new_keys(state_dict_keys: list) -> dict: + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +# -------------------------------------------------------------------------------------------- +# Helper functions for model verification +# -------------------------------------------------------------------------------------------- + + +def create_image(width, height): + """ + Helper function to create an image with a blue circle on a red background. + """ + image = Image.new("RGB", (width, height), color="red") + draw = ImageDraw.Draw(image) + center_x = image.width // 2 + center_y = image.height // 2 + radius = min(center_x, center_y) // 8 * 7 + draw.ellipse( + (center_x - radius, center_y - radius, center_x + radius, center_y + radius), + fill="blue", + outline="green", + width=image.width // 20, + ) + return image + + +def prepare_inputs(): + """ + Prepare inputs for the model. + """ + text = [ + "circle", + "ellipsoid", + "blue circle on red background", + "blue circle with green border on red background", + "green circle on red background", + "a dog", + "a blue dog with a green border on a red background", + ] + img224 = create_image(224, 224) + img1024 = create_image(1024, 1024) + img224_1024 = create_image(1024, 224) + + images = [img224, img1024, img224_1024] + return text, images + + +# -------------------------------------------------------------------------------------------- +# Convert model +# -------------------------------------------------------------------------------------------- + + +@torch.no_grad() +def convert_siglip2_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our Siglip2 structure. 
+ """ + + # Define Siglip2 configuration + config = get_siglip2_config(model_name) + + checkpoint = MODEL_NAME_TO_CHECKPOINT_PATH[model_name] + if not os.path.exists(checkpoint): + org, repo_id, *filepath = checkpoint.split("/") + checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath)) + + print(f"Loading checkpoint from {checkpoint}...") + data = np.load(checkpoint) + state_dict = flatten_nested_dict(data) + state_dict = split_encoderblock_layers(state_dict) + state_dict = merge_qkv_for_head(state_dict, config) + + # Rename and transform weights + print("Renaming and transforming weights...") + + original_keys = list(state_dict.keys()) + hf_keys = convert_old_keys_to_new_keys(original_keys) + + new_state_dict = {} + for original_key in original_keys: + new_key = hf_keys[original_key] + parameter = state_dict.pop(original_key) + + hidden_size = config.vision_config.hidden_size if "vision" in new_key else config.text_config.hidden_size + + if any(k in new_key for k in ("out_proj", "q_proj", "k_proj", "v_proj", "position_embedding")): + parameter = parameter.reshape(-1, hidden_size) + + # Transpose every weight except for position_embedding and token_embedding + if new_key.endswith("weight") and "position_embedding" not in new_key and "token_embedding" not in new_key: + parameter = parameter.T + + # Reshape every bias + if new_key.endswith("bias"): + parameter = parameter.reshape(-1) + + new_state_dict[new_key] = torch.from_numpy(parameter) + + # load HuggingFace model + print("Loading HuggingFace model...") + model = Siglip2Model(config).eval() + model.load_state_dict(new_state_dict) + + # Create processor + print("Creating processor...") + # TODO: update with more checkpoints + tokenizer = get_siglip2_tokenizer() + image_processor = get_siglip2_image_processor(config.vision_config.patch_size, max_num_patches=256) + processor = Siglip2Processor(image_processor=image_processor, tokenizer=tokenizer) + + # Verify logits + if verify_logits: + print(f"Verifying logits for {model_name}...") + text, images = prepare_inputs() + inputs = processor(text=text, images=images, padding="max_length", max_length=64, return_tensors="pt") + outputs = model(**inputs) + torch.testing.assert_close(outputs.logits_per_text, EXPECTED_OUTPUTS[model_name], atol=1e-3, rtol=1e-3) + + # Save model + if pytorch_dump_folder_path is not None: + dst_dir = os.path.join(pytorch_dump_folder_path, model_name) + print(f"Saving model {model_name} to {dst_dir}...") + model.save_pretrained(dst_dir) + print(f"Saving processor to {dst_dir}...") + processor.save_pretrained(dst_dir) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to the HuggingFace Hub...") + model.push_to_hub(f"qubvel-hf/{model_name}", private=True) + processor.push_to_hub(f"qubvel-hf/{model_name}", private=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="siglip2-base-patch16-naflex", + type=str, + choices=MODEL_NAME_TO_CHECKPOINT_PATH.keys(), + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="checkpoints/", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--verify_logits", + action="store_true", + help="Whether to verify logits against the original implementation.", + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_siglip2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/siglip2/image_processing_siglip2.py b/docs/transformers/src/transformers/models/siglip2/image_processing_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..53d53797f2c060739aaf55d7aa4e084e1e2ae704 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/image_processing_siglip2.py @@ -0,0 +1,343 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for SigLIP2.""" + +import math +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + from PIL import Image + + +@lru_cache(maxsize=256) +def get_image_size_for_max_num_patches( + image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5 +) -> Tuple[int, int]: + """ + Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch. + + Args: + image_height (`int`): + Original image height. + image_width (`int`): + Original image width. + patch_size (`int`): + Patch size for processing. + max_num_patches (`int`): + Maximum number of patches. + eps (`float`): + Small threshold for binary search. 
+
+    Returns:
+        `Tuple[int, int]`: `(target_height, target_width)`.
+    """
+
+    def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
+        scaled_size = size * scale
+        scaled_size = math.ceil(scaled_size / patch_size) * patch_size  # make divisible by patch_size
+        scaled_size = max(patch_size, scaled_size)  # ensure at least 1 patch
+        return int(scaled_size)
+
+    # Binary search for optimal scale
+    scale_min, scale_max = eps / 10, 100.0
+    while (scale_max - scale_min) >= eps:
+        scale = (scale_min + scale_max) / 2
+        target_height = get_scaled_image_size(scale, image_height, patch_size)
+        target_width = get_scaled_image_size(scale, image_width, patch_size)
+        num_patches = (target_height / patch_size) * (target_width / patch_size)
+
+        if num_patches <= max_num_patches:
+            scale_min = scale
+        else:
+            scale_max = scale
+
+    scale = scale_min
+    target_height = get_scaled_image_size(scale, image_height, patch_size)
+    target_width = get_scaled_image_size(scale, image_width, patch_size)
+    return target_height, target_width
+
+
+def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
+    """
+    Convert a 3D array image of shape (image_height, image_width, num_channels) into a 2D array of patches of shape
+    (num_patches_height * num_patches_width, patch_size * patch_size * num_channels).
+    """
+    image_height, image_width, num_channels = image.shape
+    num_patches_height = image_height // patch_size
+    num_patches_width = image_width // patch_size
+    patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels)
+    patched_image = patched_image.transpose(0, 2, 1, 3, 4)
+    patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
+    return patched_image
+
+
+def pad_along_first_dim(array: np.ndarray, target_length: int, pad_value: int = 0) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Pad the array along the first dimension and return it together with a 0/1 mask marking which rows are real data.
+    """
+    current_length = array.shape[0]
+    padding_length = target_length - current_length
+    mask = np.ones((target_length,), dtype=np.int32)
+    if padding_length > 0:
+        paddings = [(0, padding_length)] + [(0, 0)] * (array.ndim - 1)
+        array = np.pad(array, paddings, mode="constant", constant_values=pad_value)
+        mask[-padding_length:] = 0
+    return array, mask
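As a worked example of these helpers (the numbers follow from the definitions above; treat them as a sketch rather than a test vector):

```python
import numpy as np

# A 300x500 image, 16-pixel patches, at most 256 patches:
# the binary search settles on (192, 320), i.e. 12 x 20 = 240 patches <= 256.
h, w = get_image_size_for_max_num_patches(300, 500, patch_size=16, max_num_patches=256)

image = np.zeros((h, w, 3), dtype=np.float32)
patches = convert_image_to_patches(image, patch_size=16)  # shape (240, 768) = (12 * 20, 16 * 16 * 3)
padded, mask = pad_along_first_dim(patches, 256)          # shape (256, 768); mask.sum() == 240
```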
+class Siglip2ImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a SigLIP2 image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
+            Can be overridden by `do_resize` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
+            `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+            Mean to use if normalizing the image. This is a float or a list of floats with length equal to the number
+            of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+            Standard deviation to use if normalizing the image. This is a float or a list of floats with length equal
+            to the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
+            method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch the image will be split to.
+        max_num_patches (`int`, *optional*, defaults to 256):
+            The image will be resized to have at most this number of patches,
+            and then padded in "patch" dimension to match this number exactly.
+    """
+
+    model_input_names = ["pixel_values", "pixel_attention_mask", "spatial_shapes"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        resample: "PILImageResampling" = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: Optional[bool] = None,
+        patch_size: int = 16,
+        max_num_patches: int = 256,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
+        image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
+
+        self.do_resize = do_resize
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.do_convert_rgb = do_convert_rgb
+        self.patch_size = patch_size
+        self.max_num_patches = max_num_patches
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        resample: Optional["PILImageResampling"] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        do_convert_rgb: Optional[bool] = None,
+        patch_size: Optional[int] = None,
+        max_num_patches: Optional[int] = None,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + Patch size for processing, same as the patch size used in the model. + max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): + Maximum number of patches per image, the image will be resized to have at most this number of patches. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches + + # Explicitly specify data format to be channels last for image preprocessing. + # Image processor does not support different output formats, because it returns patches. + data_format = ChannelDimension.LAST + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. 
+ images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_masks = [] + pixel_values = [] + spatial_shapes = [] + + for image in images: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + if do_resize: + height, width = get_image_size_for_max_num_patches( + image_height=image.shape[0], + image_width=image.shape[1], + patch_size=patch_size, + max_num_patches=max_num_patches, + ) + image = resize(image=image, size=(height, width), resample=resample, input_data_format=data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=data_format) + + patches = convert_image_to_patches(image, patch_size) + patches, mask = pad_along_first_dim(patches, max_num_patches) + num_patches_height = image.shape[0] // patch_size + num_patches_width = image.shape[1] // patch_size + + spatial_shapes.append((num_patches_height, num_patches_width)) + pixel_values.append(patches) + pixel_masks.append(mask) + + batch_feature = BatchFeature( + data={ + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_masks, + "spatial_shapes": spatial_shapes, + }, + tensor_type=return_tensors, + ) + + return batch_feature + + +__all__ = ["Siglip2ImageProcessor"] diff --git a/docs/transformers/src/transformers/models/siglip2/image_processing_siglip2_fast.py b/docs/transformers/src/transformers/models/siglip2/image_processing_siglip2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd7cef6293e7e2e3fb42b39c5486e348b4c0efb --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/image_processing_siglip2_fast.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
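+# NOTE: "fast" image processors process images as `torch.Tensor`s via torchvision instead of numpy,
+# avoiding PIL/numpy round-trips in torch pipelines. A minimal usage sketch (the checkpoint id is
+# taken from the examples in modeling_siglip2.py and is an assumption, not part of this file):
+#
+#     from transformers import AutoImageProcessor
+#     processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224", use_fast=True)
+#     features = processor(images=image, return_tensors="pt")
+#     # features["pixel_values"]:         (batch, max_num_patches, patch_size * patch_size * channels)
+#     # features["pixel_attention_mask"]: (batch, max_num_patches)
+#     # features["spatial_shapes"]:       (batch, 2)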
+"""Fast Image processor class for SigLIP2.""" + +from typing import List, Optional, Tuple, Union + +import torch + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, +) +from ...image_utils import ( + ImageInput, + PILImageResampling, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, +) +from .image_processing_siglip2 import get_image_size_for_max_num_patches + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +logger = logging.get_logger(__name__) + + +def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor": + """ + Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + num_channels, image_height, image_width = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size) + patched_image = patched_image.permute(1, 3, 2, 4, 0) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +def pad_along_first_dim( + tensor: "torch.Tensor", target_length: int, pad_value: int = 0 +) -> Tuple["torch.Tensor", "torch.Tensor"]: + """ + Pad the tensor along the first dimension. + """ + current_length = tensor.shape[0] + padding_length = target_length - current_length + mask = torch.ones((target_length,), dtype=torch.int32) + if padding_length > 0: + padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length] + tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value) + mask[-padding_length:] = 0 + return tensor, mask + + +class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + patch_size: Optional[int] + max_num_patches: Optional[int] + + +@add_start_docstrings( + r"Constructs a fast Siglip2 image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to 256): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. 
+ """, +) +class Siglip2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + do_resize = True + do_rescale = True + do_normalize = True + patch_size = 16 + max_num_patches = 256 + valid_kwargs = Siglip2FastImageProcessorKwargs + unused_kwargs = ["size", "do_center_crop", "crop_size"] + + def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _validate_preprocess_kwargs(self, **kwargs) -> tuple: + # Remove do_resize from kwargs to not raise an error as size is None + kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: List["torch.Tensor"], + do_resize: bool, + patch_size: int, + max_num_patches: int, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + pixel_masks = [] + pixel_values = [] + spatial_shapes = [] + + for image in images: + if do_resize: + height, width = get_image_size_for_max_num_patches( + image_height=image.shape[1], + image_width=image.shape[2], + patch_size=patch_size, + max_num_patches=max_num_patches, + ) + side_dict = SizeDict(height=height, width=width) + image = self.resize(image=image, size=side_dict, interpolation=interpolation) + + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + + # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels) + patches = convert_image_to_patches(image, patch_size) + patches, mask = pad_along_first_dim(patches, max_num_patches) + + num_patches_height = image.shape[1] // patch_size + num_patches_width = image.shape[2] // patch_size + + spatial_shapes.append((num_patches_height, num_patches_width)) + pixel_values.append(patches) + pixel_masks.append(mask) + + pixel_values = torch.stack(pixel_values) + pixel_masks = torch.stack(pixel_masks) + spatial_shapes = torch.tensor(spatial_shapes) + + batch_feature = BatchFeature( + data={ + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_masks, + "spatial_shapes": spatial_shapes, + }, + tensor_type=return_tensors, + ) + return batch_feature + + +__all__ = ["Siglip2ImageProcessorFast"] diff --git a/docs/transformers/src/transformers/models/siglip2/modeling_siglip2.py b/docs/transformers/src/transformers/models/siglip2/modeling_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..0999ff69845b1854b65203dfc8cf77f0ec35831b --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/modeling_siglip2.py @@ -0,0 +1,1447 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from 
src/transformers/models/siglip2/modular_siglip2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_siglip2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Siglip2Config" + + +@dataclass +class Siglip2VisionOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Siglip2TextOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Siglip2Output(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`Siglip2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Siglip2VisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. + + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + + def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`List[Tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional 
embeddings to + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + # Get positional resized and padded positional embeddings + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1] + ) + + # Add positional embeddings to patch embeddings + embeddings = patch_embeds + resized_positional_embeddings + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class Siglip2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Siglip2EncoderLayer`]. 
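+
+    The same encoder implementation is used by both the vision and text towers; each tower
+    instantiates its own copy with its respective config.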
+ + Args: + config: Siglip2Config + """ + + def __init__(self, config: Siglip2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @can_return_tuple + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +SIGLIP2_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
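+        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
+            Mask indicating which patches are real (1) and which are padding (0), as produced by
+            [`Siglip2ImageProcessor`].
+        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            Height and width of each image in number of patches, used to resize the positional
+            embeddings.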
+""" + + +class Siglip2VisionTransformer(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = Siglip2MultiheadAttentionPoolingHead(config) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Siglip2TextEmbeddings(nn.Module): + def __init__(self, config: Siglip2TextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) 
+ + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +SIGLIP2_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class Siglip2TextTransformer(nn.Module): + def __init__(self, config: Siglip2TextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = Siglip2TextEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, config.projection_size) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +SIGLIP2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Siglip2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Siglip2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Siglip2Config + base_model_prefix = "siglip2" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "Siglip2TextEmbeddings", + "Siglip2EncoderLayer", + "Siglip2VisionEmbeddings", + "Siglip2EncoderLayer", + "Siglip2MultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Siglip2VisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, Siglip2Config) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, Siglip2Attention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, Siglip2MLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, Siglip2MultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, Siglip2Model): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, Siglip2ForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@add_start_docstrings( + """The text model from Siglip2 without any head or projection on top.""", + SIGLIP2_START_DOCSTRING, +) +class Siglip2TextModel(Siglip2PreTrainedModel): + config_class = Siglip2TextConfig + + def __init__(self, config: Siglip2TextConfig): + super().__init__(config) + self.text_model = Siglip2TextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, Siglip2TextModel + + >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224") + >>> tokenizer = 
AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class Siglip2MultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from Siglip2 without any head or projection on top.""", + SIGLIP2_START_DOCSTRING, +) +class Siglip2VisionModel(Siglip2PreTrainedModel): + config_class = Siglip2VisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2VisionConfig): + super().__init__(config) + + self.vision_model = Siglip2VisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Siglip2VisionModel + + >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = 
outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +@add_start_docstrings(SIGLIP2_START_DOCSTRING) +class Siglip2Model(Siglip2PreTrainedModel): + config_class = Siglip2Config + + def __init__(self, config: Siglip2Config): + super().__init__(config) + + if not isinstance(config.text_config, Siglip2TextConfig): + raise TypeError( + "config.text_config is expected to be of type Siglip2TextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, Siglip2VisionConfig): + raise TypeError( + "config.vision_config is expected to be of type Siglip2VisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = Siglip2TextModel._from_config(text_config) + vision_model = Siglip2VisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`Siglip2TextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. 
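+        # i.e. `output_attentions` / `output_hidden_states` fall back to the top-level
+        # `Siglip2Config` rather than to `config.text_config` when not passed explicitly.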
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = text_outputs.pooler_output + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`Siglip2VisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = vision_outputs.pooler_output + + return pooled_output + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Siglip2Output, config_class=Siglip2Config) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Siglip2Output: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return Siglip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings( + """ + Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. 
+ """, + SIGLIP2_START_DOCSTRING, +) +class Siglip2ForImageClassification(Siglip2PreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = Siglip2VisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `Siglip2Model` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224") + >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + if pixel_attention_mask is not None: + pool_mask = pixel_attention_mask[..., None].to(sequence_output.device) + sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1) + else: + sequence_output = torch.mean(sequence_output, dim=1) + + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + "Siglip2ForImageClassification", +] diff --git a/docs/transformers/src/transformers/models/siglip2/modular_siglip2.py b/docs/transformers/src/transformers/models/siglip2/modular_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..23df3b0413d9c2719925a2cb93418ddc70d11f5a --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/modular_siglip2.py @@ -0,0 +1,511 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from transformers.models.siglip.modeling_siglip import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+    SiglipForImageClassification,
+    SiglipModel,
+    SiglipMultiheadAttentionPoolingHead,
+    SiglipOutput,
+    SiglipPreTrainedModel,
+    SiglipTextModel,
+    SiglipTextModelOutput,
+    SiglipVisionModel,
+    SiglipVisionModelOutput,
+    SiglipVisionTransformer,
+)
+
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+
+
+class Siglip2TextConfig(SiglipTextConfig):
+    pass
+
+
+class Siglip2VisionConfig(SiglipVisionConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
+    Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
+    [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input images.
+        num_patches (`int`, *optional*, defaults to 256):
+            The number of patches in the image with the size of (`patch_size`, `patch_size`).
+            The image is resized to fill at most this number of patches while preserving
+            the aspect ratio. If the resulting number of patches is lower, the image is
+            padded in the "patch" dimension.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
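+
+    Note that, unlike [`SiglipVisionConfig`], there is no fixed `image_size`: the encoder accepts variable
+    resolutions, and `num_patches` bounds the length of the patch sequence instead.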
+ + Example: + + ```python + >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel + + >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration + >>> configuration = Siglip2VisionConfig() + + >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration + >>> model = Siglip2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_patches = num_patches + del self.image_size + + +class Siglip2Config(SiglipConfig): + pass + + +class Siglip2VisionOutput(SiglipVisionModelOutput): + pass + + +class Siglip2TextOutput(SiglipTextModelOutput): + pass + + +class Siglip2Output(SiglipOutput): + pass + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. 
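+
+        For instance, a `(16, 16, embed_dim)` grid interpolated for `spatial_shapes = [[8, 12]]` with
+        `max_length=256` fills the first `8 * 12 = 96` output positions with image-specific embeddings and
+        pads the remaining 160 positions.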
+
+        Args:
+            positional_embeddings (`torch.Tensor`):
+                Position embeddings of shape (height, width, embed_dim)
+            spatial_shapes (`torch.LongTensor`):
+                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+            max_length (`int`):
+                Maximum length of the positional embeddings to pad resized positional embeddings to
+
+        Returns:
+            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+        """
+        batch_size = spatial_shapes.shape[0]
+        embed_dim = positional_embeddings.shape[-1]
+        source_dtype = positional_embeddings.dtype
+
+        resulted_positional_embeddings = torch.empty(
+            (batch_size, max_length, embed_dim),
+            device=positional_embeddings.device,
+            dtype=source_dtype,
+        )
+
+        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+        if positional_embeddings.device.type == "cpu":
+            positional_embeddings = positional_embeddings.to(torch.float32)
+
+        for i in range(batch_size):
+            # (1, dim, height, width) -> (1, dim, target_height, target_width)
+            height, width = spatial_shapes[i]
+            resized_embeddings = F.interpolate(
+                positional_embeddings,
+                size=(height, width),
+                mode="bilinear",
+                align_corners=False,
+                antialias=True,
+            )
+
+            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
+
+            # Cast to original dtype
+            resized_embeddings = resized_embeddings.to(source_dtype)
+
+            resulted_positional_embeddings[i, : height * width] = resized_embeddings
+            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+        return resulted_positional_embeddings
+
+    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
+        """
+        Args:
+            pixel_values (`torch.FloatTensor`):
+                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
+            spatial_shapes (`torch.LongTensor`):
+                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+        """
+
+        # Apply patch embeddings to already patchified pixel values
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+        # Get resized and padded positional embeddings
+        positional_embeddings = self.position_embedding.weight.reshape(
+            self.position_embedding_size, self.position_embedding_size, -1
+        )
+        resized_positional_embeddings = self.resize_positional_embeddings(
+            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+        )
+
+        # Add positional embeddings to patch embeddings
+        embeddings = patch_embeds + resized_positional_embeddings
+        return embeddings
+
+
+class Siglip2VisionTransformer(SiglipVisionTransformer):
+    def __init__(self, config: Siglip2VisionConfig):
+        super().__init__(config)
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+    # Update: add `spatial_shapes` and `attention_mask`
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        attention_mask: torch.Tensor,
+        spatial_shapes: torch.LongTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Siglip2PreTrainedModel(SiglipPreTrainedModel): + pass + + +class Siglip2TextModel(SiglipTextModel): + pass + + +class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + def __init__(self, config: Siglip2VisionConfig): + super().__init__(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Siglip2VisionModel(SiglipVisionModel): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> BaseModelOutputWithPooling: + return self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class Siglip2Model(SiglipModel): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> torch.FloatTensor: + # Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = vision_outputs.pooler_output + + return pooled_output + + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Siglip2Output: + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return Siglip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class Siglip2ForImageClassification(SiglipForImageClassification): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: 
Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> ImageClassifierOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + if pixel_attention_mask is not None: + pool_mask = pixel_attention_mask[..., None].to(sequence_output.device) + sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1) + else: + sequence_output = torch.mean(sequence_output, dim=1) + + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Siglip2Config", + "Siglip2TextConfig", + "Siglip2VisionConfig", + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + "Siglip2ForImageClassification", +] diff --git a/docs/transformers/src/transformers/models/siglip2/processing_siglip2.py b/docs/transformers/src/transformers/models/siglip2/processing_siglip2.py new file mode 100644 index 0000000000000000000000000000000000000000..16e0ea1b6b8df5ef74dc9f1472d779407dc16c08 --- /dev/null +++ b/docs/transformers/src/transformers/models/siglip2/processing_siglip2.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP2. 
+""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class Siglip2ImagesKwargs(ImagesKwargs, total=False): + max_num_patches: Optional[int] + patch_size: Optional[int] + + +class Siglip2ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Siglip2ImagesKwargs + + _defaults = { + "text_kwargs": { + "padding": "max_length", + "truncation": True, + "max_length": 64, + }, + "images_kwargs": { + "max_num_patches": 256, + "patch_size": 16, + }, + } + + +class Siglip2Processor(ProcessorMixin): + r""" + Constructs a Siglip2 processor which wraps a Siglip2 image processor and a Gemma tokenizer into a single processor. + + [`Siglip2Processor`] offers all the functionalities of [`Siglip2ImageProcessor`] and [`GemmaTokenizerFast`]. See the + [`~Siglip2Processor.__call__`] and [`~Siglip2Processor.decode`] for more information. + + Args: + image_processor ([`Siglip2ImageProcessor`]): + The image processor is a required input. + tokenizer ([`GemmaTokenizerFast`]): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[Union[ImageInput, List[ImageInput], List[List[ImageInput]]]] = None, + text: Optional[Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Siglip2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` argument to + Siglip2ImageProcessor's [`~Siglip2ImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). 
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*, defaults to 64): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*, defaults to `True`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'pt'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_attention_mask** -- Attention mask for the pixel values. Returned when `images` is not `None`. + - **spatial_shapes** -- The number of horizontal and vertical patches per image. + Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Siglip2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text is not None and images is not None: + encoding.update(image_features) + return encoding + elif text is not None: + return encoding + else: + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] + return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Siglip2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Siglip2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Siglip2Processor"] diff --git a/docs/transformers/src/transformers/models/smolvlm/__init__.py b/docs/transformers/src/transformers/models/smolvlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..706c9928300f408dbae61df4ccd17879d1f53fdf --- /dev/null +++ b/docs/transformers/src/transformers/models/smolvlm/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_smolvlm import * + from .image_processing_smolvlm import * + from .modeling_smolvlm import * + from .processing_smolvlm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/smolvlm/configuration_smolvlm.py b/docs/transformers/src/transformers/models/smolvlm/configuration_smolvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8544156832ece63f5eeac3dc815e33d9ed8c20 --- /dev/null +++ b/docs/transformers/src/transformers/models/smolvlm/configuration_smolvlm.py @@ -0,0 +1,196 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_smolvlm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# Written by Orr Zohar +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class SmolVLMVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SmolVLMVisionModel`]. It is used to instantiate a + SmolVLM vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint + [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) used in SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1152):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input images.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 32):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Example:
+
+    ```python
+    >>> from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+    >>> from transformers.models.smolvlm.configuration_smolvlm import SmolVLMVisionConfig
+
+    >>> # Initializing a SmolVLMVisionConfig with google/siglip-so400m-patch14-384 style configuration
+    >>> configuration = SmolVLMVisionConfig()
+
+    >>> # Initializing a SmolVLMVisionTransformer (with random weights) from the google/siglip-so400m-patch14-384 style configuration
+    >>> model = SmolVLMVisionTransformer(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "smolvlm_vision"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=1152,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=224,
+        patch_size=32,
+        hidden_act="gelu_pytorch_tanh",
+        layer_norm_eps=1e-6,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+
+
+class SmolVLMConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a
+    SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the SmolVLM
+    [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should cache the key/value pairs of the attention mechanism. Only + relevant if `config.is_decoder=True`. + image_token_id (`int`, *optional*, defaults to 128257): + The id of the "image" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the token embeddings. + vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): + Custom vision config or dict for the vision tower + text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): + Custom text config or dict for the text model + scale_factor (`int`, *optional*, defaults to 2): + The scale factor for the image encoder. + pad_token_id (`int`, *optional*, defaults to 128002): + The id of the padding token. + + Example: + ```python + >>> from transformers import SmolVLMModel, SmolVLMConfig + >>> # Initializing configuration + >>> configuration = SmolVLMConfig() + >>> # Initializing a model from the configuration + >>> model = SmolVLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "smolvlm" + sub_configs = {"text_config": AutoConfig, "vision_config": SmolVLMVisionConfig} + + def __init__( + self, + use_cache=True, + image_token_id=128257, + tie_word_embeddings=False, + vision_config=None, + text_config=None, + scale_factor=2, + pad_token_id=128_002, + **kwargs, + ): + self.image_token_id = image_token_id + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings + + if vision_config is None: + self.vision_config = SmolVLMVisionConfig() + logger.info("vision_config is None, using default vision config") + elif isinstance(vision_config, dict): + self.vision_config = SmolVLMVisionConfig(**vision_config) + elif isinstance(vision_config, SmolVLMVisionConfig): + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + logger.info("text_config is None, using default text config") + text_config = CONFIG_MAPPING["llama"]( + rms_norm_eps=1e-5, + pad_token_id=pad_token_id, + tie_word_embeddings=False, + ) + + self.text_config = text_config + self.scale_factor = scale_factor + + super().__init__(**kwargs, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings) + + +__all__ = ["SmolVLMVisionConfig", "SmolVLMConfig"] diff --git a/docs/transformers/src/transformers/models/smolvlm/image_processing_smolvlm.py b/docs/transformers/src/transformers/models/smolvlm/image_processing_smolvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..abf9353fd4ea8c919f337c46dadbc305039c96f5 --- /dev/null +++ b/docs/transformers/src/transformers/models/smolvlm/image_processing_smolvlm.py @@ -0,0 +1,851 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_smolvlm.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+# Written by Orr Zohar
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import PaddingMode, pad, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    get_image_size,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_nested_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+if is_vision_available():
+    import PIL
+    from PIL import Image
+
+
+logger = logging.get_logger(__name__)
+MAX_IMAGE_SIZE = 4096  # 4k resolution as absolute maximum
+
+
+def _resize_output_size_rescale_to_max_len(
+    height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
+) -> Tuple[int, int]:
+    """
+    Get the output size of the image after resizing, given the minimum and maximum allowed sizes.
+    Args:
+        height (`int`):
+            Height of the input image.
+        width (`int`):
+            Width of the input image.
+        min_len (`int`, *optional*, defaults to 1):
+            Minimum size of the output image.
+        max_len (`int`, *optional*, defaults to the maximum size of the image):
+            Maximum size of the output image.
+    Returns:
+        The output size of the image after resizing.
+    """
+    max_len = max(height, width) if max_len is None else max_len
+    aspect_ratio = width / height
+
+    if width >= height:
+        width = max_len
+        height = int(width / aspect_ratio)
+        if height % 2 != 0:
+            height += 1
+    elif height > width:
+        height = max_len
+        width = int(height * aspect_ratio)
+        if width % 2 != 0:
+            width += 1
+
+    # Avoid resizing to a size smaller than min_len
+    height = max(height, min_len)
+    width = max(width, min_len)
+    return height, width
+
+
+def _resize_output_size_scale_below_upper_bound(
+    height: int, width: int, max_len: Optional[int] = None
+) -> Tuple[int, int]:
+    """
+    Get the output size of the image after resizing, given the maximum allowed size.
+    Args:
+        height (`int`):
+            Height of the input image.
+        width (`int`):
+            Width of the input image.
+        max_len (`int`, *optional*, defaults to the maximum size of the image):
+            Defines the maximum dimension of the image.
+    Returns:
+        The output size of the image after resizing.
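+
+    Example (illustrative):
+
+    ```python
+    >>> _resize_output_size_scale_below_upper_bound(height=3000, width=6000, max_len=4096)
+    (2048, 4096)
+    ```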
+ """ + max_len = max(height, width) if max_len is None else max_len + + aspect_ratio = width / height + if width >= height and width > max_len: + width = max_len + height = int(width / aspect_ratio) + elif height > width and height > max_len: + height = max_len + width = int(height * aspect_ratio) + + # Avoid resizing to a size smaller than 1 + height = max(height, 1) + width = max(width, 1) + return height, width + + +def get_resize_output_image_size( + image, + resolution_max_side: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + """ + Get the output size of the image after resizing given a dictionary specifying the max and min sizes. + Args: + image (`np.ndarray`): + Image to resize. + resolution_max_side (`int`): + The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the + input aspect ratio. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + Returns: + The output size of the image after resizing. + """ + height, width = get_image_size(image, channel_dim=input_data_format) + + # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio + height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side) + # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE + height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE) + return height, width + + +def get_max_height_width( + images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4)) + + max_height = max_width = float("-inf") + for images in images_list: + for image in images: + height, width = get_image_size(image, channel_dim=input_data_format) + max_height = max(height, max_height) + max_width = max(width, max_width) + return (max_height, max_width) + + +def make_pixel_mask( + image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +def convert_to_rgb( + image: np.ndarray, + palette: Optional[PIL.ImagePalette.ImagePalette] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> ImageInput: + """ + Converts an image to RGB format. + Args: + image (`np.ndarray`): + The image to convert. + palette (List[int], *optional*): + The palette to use if given. + data_format (ChannelDimension or str, *optional*): + The channel dimension format for the output image. If not provided, it will be the same as the input image. + input_data_format (ChannelDimension or str, *optional*): + The channel dimension format of the input image. 
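+
+    For instance, an RGBA pixel of `(255, 0, 0, 128)` is alpha-composited over the white background and comes
+    out as approximately `(255, 127, 127)` in the returned RGB image.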
+ """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4)) + + # For all transformations, we want to keep the same data format as the input image unless otherwise specified. + # The resized image from PIL will always have channels last, so find the input format first. + data_format = input_data_format if data_format is None else data_format + + mode = "P" if palette is not None else None + image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format) + if image.mode == "P" and palette is not None: + image.putpalette(palette) + + image_rgba = image.convert("RGBA") + background = Image.new("RGBA", image_rgba.size, (255, 255, 255)) + alpha_composite = Image.alpha_composite(background, image_rgba) + alpha_composite = alpha_composite.convert("RGB") + + output_array = np.array(alpha_composite) + # The image is always in channels last format after converting from a PIL image + output_array = to_channel_dimension_format(output_array, data_format, input_channel_dim=ChannelDimension.LAST) + return output_array + + +# FIXME Amy: make a more general crop function that isn't just centre crop +def _crop( + image: np.ndarray, + w1: int, + h1: int, + w2: int, + h2: int, + data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + if data_format is None: + data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4)) + + if data_format == ChannelDimension.FIRST: + image = image[:, h1:h2, w1:w2] + elif data_format == ChannelDimension.LAST: + image = image[h1:h2, w1:w2, :] + else: + raise ValueError("Invalid channel dimension format.") + + return image + + +class SmolVLMImageProcessor(BaseImageProcessor): + r""" + Constructs a SmolVLM image processor. + Args: + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA. + Only has an effect if the input image is in the PIL format. + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the + shortest edge resized to keep the input aspect ratio. + size (`Dict`, *optional*, defaults to `{"longest_edge": 4 * 364}`): + Controls the size of the output image. This is a dictionary containing the key "longest_edge". + The image will be resized such that the longest edge is <= `size["longest_edge"]` and the shortest edge is resized + to keep the input aspect ratio. + resample (`Resampling`, *optional*, defaults to `Resampling.LANCZOS`): + Resampling filter to use when resizing the image. + do_image_splitting (`bool`, *optional*, defaults to `True`): + Whether to split the image into sub-images concatenated with the original image. They are split into patches + such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`. + max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`): + Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge". + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1. + rescale_factor (`float`, *optional*, defaults to `1/255`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. 
If set to `True`, the image is normalized to have a mean of `image_mean` and
+            a standard deviation of `image_std`.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_pad (`bool`, *optional*, defaults to `True`):
+            Whether or not to pad the images to the largest height and width in the batch and number of images per
+            sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
+    """
+
+    model_input_names = ["pixel_values", "pixel_attention_mask"]
+
+    def __init__(
+        self,
+        do_convert_rgb: bool = True,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.LANCZOS,
+        do_image_splitting: bool = True,
+        max_image_size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_pad: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.do_convert_rgb = do_convert_rgb
+        self.do_resize = do_resize
+        self.size = size if size is not None else {"longest_edge": 4 * 364}
+        self.resample = resample
+        self.do_image_splitting = do_image_splitting
+        self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 364}
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.do_pad = do_pad
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.LANCZOS,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image. The longest edge of the image is resized to size["longest_edge"], with the shortest edge
+        resized to keep the input aspect ratio. Can also be used with size["height"] and size["width"].
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+                Resampling filter to use when resizing the image.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the output image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
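+
+        For instance, a `512x1024` (height x width) image resized with `size={"longest_edge": 1456}` comes out
+        as `728x1456`: the longest edge is rescaled to 1456 and the shorter edge follows the aspect ratio
+        (rounded up to an even number when needed).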
+ """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4)) + + # For all transformations, we want to keep the same data format as the input image unless otherwise specified. + # The resized image from PIL will always have channels last, so find the input format first. + data_format = input_data_format if data_format is None else data_format + + if "longest_edge" in size: + size = get_resize_output_image_size( + image, resolution_max_side=size["longest_edge"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.") + + image_mode = None + if image.ndim == 2 or image.shape[-1] == 1: + image_mode = "P" + image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format) + + resized_image = image.resize((size[1], size[0]), resample=resample) + resized_image = np.array(resized_image) + + # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image + # so we need to add it back if necessary. + resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image + # The image is always in channels last format after converting from a PIL image + resized_image = to_channel_dimension_format( + resized_image, data_format, input_channel_dim=ChannelDimension.LAST + ) + return resized_image + + def split_image( + self, + image, + max_image_size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.LANCZOS, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Split an image into squares of side max_image_size and the original image resized to max_image_size. + That means that a single image becomes a sequence of images. + This is a "trick" to spend more compute on each image with no changes in the vision encoder. + 1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio. + 2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)` + sub-images of the same size each (image_size, image_size). Typically, 364x364. + 3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width. + Args: + image (`np.ndarray`): + Images to split. + max_image_size (`Dict[str, int]`): + Maximum size of the output image. If the image is larger than this size, it will be split into + patches of this size, and the original image will be concatenated with the patches, resized to max_size. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`): + Resampling filter to use when resizing the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
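+
+        Example (an illustrative sketch added for clarity; the 728x1092 zero-filled array is a stand-in
+        input, not part of the original docstring):
+
+        ```python
+        >>> import numpy as np
+        >>> processor = SmolVLMImageProcessor()
+        >>> image = np.zeros((728, 1092, 3), dtype=np.uint8)  # 2 * 364 by 3 * 364
+        >>> frames, rows, cols = processor.split_image(image, max_image_size={"longest_edge": 364})
+        >>> len(frames), rows, cols  # 6 crops of 364x364 plus the resized global image
+        (7, 2, 3)
+        ```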
+ """ + height, width = get_image_size(image, channel_dim=input_data_format) + max_height = max_width = max_image_size["longest_edge"] + + frames = [] + if height > max_height or width > max_width: + # Calculate the number of splits + num_splits_h = math.ceil(height / max_height) + num_splits_w = math.ceil(width / max_width) + # Calculate the optimal width and height for the sub-images + optimal_height = math.ceil(height / num_splits_h) + optimal_width = math.ceil(width / num_splits_w) + + # Iterate through each row and column + for r in range(num_splits_h): + for c in range(num_splits_w): + # Calculate the starting point of the crop + start_x = c * optimal_width + start_y = r * optimal_height + + # Calculate the ending point of the crop + end_x = min(start_x + optimal_width, width) + end_y = min(start_y + optimal_height, height) + + # Crop the image + cropped_image = _crop( + image, + start_x, + start_y, + end_x, + end_y, + data_format=data_format, + ) + frames.append(cropped_image) + + # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency + global_image_height, global_image_width = max_height, max_width + if height != global_image_height or width != global_image_width: + image = self.resize( + image, + {"height": global_image_height, "width": global_image_width}, + resample=resample, + input_data_format=data_format, + ) + else: + num_splits_h, num_splits_w = 0, 0 + + frames.append(image) + + return frames, num_splits_h, num_splits_w + + def resize_for_vision_encoder( + self, + image: np.ndarray, + vision_encoder_max_size: int, + resample: PILImageResampling = PILImageResampling.LANCZOS, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio. + Args: + image (`np.ndarray`): + Images to resize. + vision_encoder_max_size (`int`): + Maximum size of the output image. If the image is larger than this size, it will be split into + patches of this size, and the original image will be concatenated with the patches, resized to max_size. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`): + Resampling filter to use when resizing the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. 
If not provided, it will be inferred.
+        """
+        height, width = get_image_size(image, channel_dim=input_data_format)
+
+        aspect_ratio = width / height
+        if width >= height:
+            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+            height = int(width / aspect_ratio)
+            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+        elif height > width:
+            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
+            width = int(height * aspect_ratio)
+            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
+        new_size = {"height": height, "width": width}
+        return self.resize(
+            image, size=new_size, resample=resample, input_data_format=input_data_format, data_format=data_format
+        )
+
+    def _pad_image(
+        self,
+        image: np.ndarray,
+        output_size: Tuple[int, int],
+        constant_values: Union[float, Iterable[float]] = 0,
+        data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """
+        Pad an image with zeros to the given size.
+        """
+        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+        output_height, output_width = output_size
+
+        pad_bottom = output_height - input_height
+        pad_right = output_width - input_width
+        padding = ((0, pad_bottom), (0, pad_right))
+        padded_image = pad(
+            image,
+            padding,
+            mode=PaddingMode.CONSTANT,
+            constant_values=constant_values,
+            data_format=data_format,
+            input_data_format=input_data_format,
+        )
+        return padded_image
+
+    def pad(
+        self,
+        images: List[List[np.ndarray]],
+        constant_values: Union[float, Iterable[float]] = 0,
+        return_pixel_mask: bool = True,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Pads each image in a batch to the bottom and right with zeros, up to the largest height and width in the
+        batch. For each sample in the batch, also pads the sample with empty images up to the maximum number of
+        images per sample in the batch. Optionally returns a pixel mask.
+        Args:
+            images (`List[List[np.ndarray]]`):
+                List of lists of images to pad. Pads to the largest height and width in the batch.
+            constant_values (`float` or `Iterable[float]`, *optional*):
+                The value to use for the padding if `mode` is `"constant"`.
+            return_pixel_mask (`bool`, *optional*, defaults to `True`):
+                Whether to return a pixel mask.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
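+
+        Example (a minimal illustrative sketch; the toy arrays are stand-ins, not part of the original
+        docstring):
+
+        ```python
+        >>> import numpy as np
+        >>> processor = SmolVLMImageProcessor()
+        >>> batch = [
+        ...     [np.zeros((10, 12, 3), dtype=np.uint8)],
+        ...     [np.zeros((8, 20, 3), dtype=np.uint8), np.zeros((16, 5, 3), dtype=np.uint8)],
+        ... ]
+        >>> padded, masks = processor.pad(batch)
+        >>> len(padded[0]), padded[0][0].shape  # every sample padded to 2 images of 16x20
+        (2, (16, 20, 3))
+        ```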
+ """ + pad_size = get_max_height_width(images, input_data_format=input_data_format) + + batch_size = len(images) + max_num_images = max(len(images_) for images_ in images) + input_data_format = ( + infer_channel_dimension_format(images[0][0], num_channels=(1, 3, 4)) + if input_data_format is None + else input_data_format + ) + data_format = input_data_format if data_format is None else data_format + + if input_data_format == ChannelDimension.FIRST: + n_channels = images[0][0].shape[0] + elif input_data_format == ChannelDimension.LAST: + n_channels = images[0][0].shape[-1] + else: + raise ValueError("Invalid channel dimension format.") + + def empty_image(size, input_data_format): + if input_data_format == ChannelDimension.FIRST: + return np.zeros((n_channels, *size), dtype=np.uint8) + elif input_data_format == ChannelDimension.LAST: + return np.zeros((*size, n_channels), dtype=np.uint8) + + padded_images_list = [ + [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size) + ] + padded_masks = [[np.zeros(pad_size, dtype=np.int64) for _ in range(max_num_images)] for _ in range(batch_size)] + + for batch_idx in range(batch_size): + for sample_idx, image in enumerate(images[batch_idx]): + padded_images_list[batch_idx][sample_idx] = self._pad_image( + image, + pad_size, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + padded_masks[batch_idx][sample_idx] = make_pixel_mask( + image, output_size=pad_size, input_data_format=input_data_format + ) + + padded_masks = padded_masks if return_pixel_mask else None + return padded_images_list, padded_masks + + def preprocess( + self, + images: ImageInput, + do_convert_rgb: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_image_splitting: Optional[bool] = None, + do_rescale: Optional[bool] = None, + max_image_size: Optional[Dict[str, int]] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_row_col_info: bool = False, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess a batch of images. + Args: + images (`ImageInput`): + A list of images to preprocess. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. With the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`): + Whether to split the image into sub-images concatenated with the original image. They are split into patches + such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`. + max_image_size (`Dict`, *optional*, defaults to `self.max_image_size`): + Maximum resolution of the images. 
If the image is larger than this size, the image is split into patches.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+                `True`.
+            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                Whether or not to pad the images to the largest height and width in the batch.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            return_row_col_info (`bool`, *optional*, defaults to `False`):
+                Whether to return the number of rows and columns of the split images. This is used by the
+                `SmolVLMProcessor` to generate prompt strings based on the number of rows and columns.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
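+
+        Example (an illustrative sketch added for clarity; the gray test image and the described behavior
+        follow from the default `size` and `max_image_size` values above and are assumptions, not documented
+        guarantees):
+
+        ```python
+        >>> import numpy as np
+        >>> from PIL import Image
+        >>> processor = SmolVLMImageProcessor()
+        >>> image = Image.fromarray(np.full((500, 800, 3), 128, dtype=np.uint8))
+        >>> out = processor.preprocess([[image]], return_tensors="pt", return_row_col_info=True)
+        >>> # With the defaults, the image is rescaled so its longest edge is 4 * 364, snapped to multiples
+        >>> # of 364, and split into 364x364 crops plus one resized global image.
+        >>> print(out["pixel_values"].shape, out["rows"], out["cols"])
+        ```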
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting + max_image_size = max_image_size if max_image_size is not None else self.max_image_size + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + + images_list = make_nested_list_of_images(images) + + if not valid_images(images_list[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # save the palettes for conversion to RGB + palettes_list = [ + [im.getpalette() if isinstance(im, Image.Image) and im.mode == "P" else None for im in images] + for images in images_list + ] + + # All transformations expect numpy arrays. + images_list = [[to_numpy_array(image) for image in images] for images in images_list] + + # Extra channel dimension for grayscale images + if input_data_format in [ChannelDimension.LAST, None]: + images_list = [ + [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list + ] + elif input_data_format == ChannelDimension.FIRST: + images_list = [ + [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list + ] + + if do_rescale and is_scaled_image(images_list[0][0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + # We assume that all images have the same channel dimension format. 
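+        # The format is inferred once from the first image and reused for the whole batch;
+        # num_channels=(1, 3, 4) accepts grayscale, RGB, and RGBA/paletted inputs.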
+ if input_data_format is None: + input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4)) + + if do_resize: + images_list = [ + [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + for images in images_list + ] + + if do_image_splitting: + # We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio + # for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size) + # for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2) + images_list = [ + [ + self.resize_for_vision_encoder( + image, max_image_size["longest_edge"], resample=resample, input_data_format=input_data_format + ) + for image in images + ] + for images in images_list + ] + images_list_split_arrays = [] + palettes_list_split_arrays = [] + images_list_rows = [] + images_list_cols = [] + for images, palettes in zip(images_list, palettes_list): + split_image_arrays = [] + split_palettes_arrays = [] + image_rows = [] + image_cols = [] + for image, palette in zip(images, palettes): + split_image_array, rows, cols = self.split_image( + image, + max_image_size=max_image_size, + input_data_format=input_data_format, + ) + split_image_arrays.extend(split_image_array) + split_palettes_arrays.extend([palette] * len(split_image_array)) + image_rows.append(rows) + image_cols.append(cols) + images_list_split_arrays.append(split_image_arrays) + palettes_list_split_arrays.append(split_palettes_arrays) + images_list_rows.append(image_rows) + images_list_cols.append(image_cols) + images_list = images_list_split_arrays + palettes_list = palettes_list_split_arrays + else: + # We square the images to max_image_size + images_list = [ + [ + self.resize( + image=image, + size={"height": max_image_size["longest_edge"], "width": max_image_size["longest_edge"]}, + resample=resample, + input_data_format=input_data_format, + ) + for image in images + ] + for images in images_list + ] + images_list_rows = [[0] * len(images) for images in images_list] + images_list_cols = [[0] * len(images) for images in images_list] + + if do_convert_rgb: + images_list = [ + [convert_to_rgb(img, palette) for img, palette in zip(images, palettes)] + for images, palettes in zip(images_list, palettes_list) + ] + + if do_rescale: + images_list = [ + [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + for images in images_list + ] + + if do_normalize: + images_list = [ + [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + for images in images_list + ] + + pixel_attention_mask = None + if do_pad: + images_list, pixel_attention_mask = self.pad( + images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format + ) + + if data_format is not None: + images_list = [ + [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in images + ] + for images in images_list + ] + + # Faster tensor conversion + data = {"pixel_values": np.array(images_list) if do_pad and return_tensors is not None else images_list} + if pixel_attention_mask is not None: + data["pixel_attention_mask"] = ( + np.array(pixel_attention_mask) if do_pad and return_tensors is not None else pixel_attention_mask + ) + + encoding = BatchFeature(data=data, tensor_type=return_tensors) + + # This is needed for 
generating correct text inputs in the processor - we don't pad to the max number of images + if return_row_col_info: + encoding["rows"] = images_list_rows + encoding["cols"] = images_list_cols + + return encoding + + +__all__ = ["SmolVLMImageProcessor"] diff --git a/docs/transformers/src/transformers/models/smolvlm/modeling_smolvlm.py b/docs/transformers/src/transformers/models/smolvlm/modeling_smolvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..6c33479e9134949703fb9838dbcf669f2b348c79 --- /dev/null +++ b/docs/transformers/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -0,0 +1,1161 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_smolvlm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# Written by Orr Zohar +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel +from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SmolVLMConfig" + + +SMOLVLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SmolVLMConfig`] or [`SmolVLMVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare SmolVLM Model outputting raw hidden-states without any specific head on top.", + SMOLVLM_START_DOCSTRING, +) +class SmolVLMPreTrainedModel(PreTrainedModel): + config_class = SmolVLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SmolVLMVisionAttention", "SmolVLMDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +class SmolVLMVisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: SmolVLMVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = 
position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SmolVLMVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + # Ignore copy + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class SmolVLMVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SmolVLMEncoderLayer(nn.Module): + def __init__(self, config: SmolVLMVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SmolVLMVisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SmolVLMVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SmolVLMEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SmolVLMEncoderLayer`]. 
+ + Args: + config: SmolVLMConfig + """ + + def __init__(self, config: SmolVLMConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SmolVLMEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +SMOLVLM_VISION_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SmolVLMVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The SmolVLM Vision Transformer Model outputting raw image embedding.", + SMOLVLM_VISION_START_DOCSTRING, +) +class SmolVLMVisionTransformer(SmolVLMPreTrainedModel): + config_class = SmolVLMVisionConfig + _supports_sdpa = True + _supports_flash_attention_2 = True + _supports_flex_attn = True + + def __init__(self, config: SmolVLMVisionConfig): + super().__init__(config) + embed_dim = config.hidden_size + + self.embeddings = SmolVLMVisionEmbeddings(config) + self.encoder = SmolVLMEncoder(config) + self.patch_size = config.patch_size + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + elif not self._use_flash_attention_2: + patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@dataclass +class SmolVLMBaseModelOutputWithPast(ModelOutput): + """ + Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding). + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. 
+ image_hidden_states of the model produced by the vision encoder + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class SmolVLMSimpleMLP(nn.Module): + def __init__(self, config): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + self.proj = nn.Linear(input_size, output_size, bias=False) + + def forward(self, x): + return self.proj(x) + + +class SmolVLMConnector(nn.Module): + def __init__(self, config): + super().__init__() + self.scale_factor = config.scale_factor + self.modality_projection = SmolVLMSimpleMLP(config) + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +SMOLVLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
+            [`CLIPImageProcessor`] for processing images).
+        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
+            Mask to avoid performing attention on padding pixel indices.
+        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The hidden states of the image encoder after modality projection.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    """SmolVLM model consisting of a SIGLIP vision encoder and Llama3 language decoder""",
+    SMOLVLM_START_DOCSTRING,
+)
+class SmolVLMModel(SmolVLMPreTrainedModel):
+    """
+    A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
+    in forward. Instead, we override inputs_merger here with custom logic.
+    """
+
+    def __init__(self, config: SmolVLMConfig):
+        super().__init__(config)
+        self.padding_idx = self.config.text_config.pad_token_id
+        self.vocab_size = self.config.text_config.vocab_size
+
+        self.vision_model = SmolVLMVisionTransformer._from_config(config.vision_config)
+        self.connector = SmolVLMConnector(config)
+        self.text_model = AutoModel.from_config(config.text_config)
+
+        self.image_seq_len = int(
+            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
+        )
+        self.image_token_id = self.config.image_token_id
+
+        self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
+
+        self.post_init()
+
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings.
+
+        This is useful for LoRA when using gradient checkpointing.
+        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
+
+        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
+        """
+
+        def get_lowest_module(module):
+            if len(list(module.children())) == 0:
+                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
+                return module
+            else:
+                # Recursively call the function on each child module
+                return get_lowest_module(list(module.children())[0])
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
+            make_inputs_require_grads
+        )
+
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_model.set_input_embeddings(value)
+
+    def inputs_merger(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
+    ):
+        """
+        This method merges the token embeddings with the image hidden states into one single sequence of vectors that
+        are fed to the transformer LM.
+        The merging happens as follows:
+        - The text token sequence is: `tok_1 tok_2 tok_3 ... tok_4`.
+        - We get the image hidden states for the image through the vision encoder, and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
+        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
+        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
+        - To fit the format of that sequence, `input_ids`, `inputs_embeds`, and `attention_mask` are all adapted to insert the image hidden states.
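+
+        Illustrative example (added for clarity, with hypothetical toy shapes): with an image block size of 2
+        and `input_ids = [tok_1, <image>, <image>, tok_2]`, positions 1 and 2 of the output receive
+        `image_hidden_states[0, 0]` and `image_hidden_states[0, 1]`, while positions 0 and 3 keep the text
+        embeddings of `tok_1` and `tok_2`.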
+ """ + _, patch_size, _ = image_hidden_states.shape + + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("At least one sample has tokens not divisible by patch_size.") + + blocks_per_sample = num_image_tokens // patch_size + + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + + merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + return merged_embeds + + @add_start_docstrings_to_model_forward( + """ + Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to + the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where + max_num_images is the maximum number of images among the batch_size samples in the batch. + Padding images are not needed beyond padding the pixel_values at the entrance of the model. + For efficiency, we only pass through the vision_model's forward the real images by + discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where + image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3. + """, + SMOLVLM_INPUTS_DOCSTRING, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache() + past_seen_tokens = past_key_values.get_seq_length() + + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if inputs_embeds is not None and image_hidden_states is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return 
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_hidden_states,
+        )
+
+
+@dataclass
+class SmolVLMCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for SmolVLM causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`).
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+            Image hidden states of the model produced by the vision encoder.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@add_start_docstrings(
+    """The SmolVLM Model with a language modeling head. It is made up of a SigLIP vision encoder with a language modeling head on top. """,
+    SMOLVLM_START_DOCSTRING,
+)
+class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
+    """
+    A subclass of Idefics3ForConditionalGeneration that uses SmolVLMModel
+    instead of the default Idefics3Model.
+ """ + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = SmolVLMModel(config) + self.image_token_id = self.config.image_token_id + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.vocab_size = config.text_config.vocab_size + + # Initialize weights and apply final processing + self.post_init() + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. + """ + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( + make_inputs_require_grads + ) + + def disable_input_require_grads(self): + self._text_require_grads_hook.remove() + self._vision_require_grads_hook.remove() + + def get_input_embeddings(self): + return self.model.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.text_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(SMOLVLM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SmolVLMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + ) -> Union[Tuple, SmolVLMCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import requests
+        >>> import torch
+        >>> from PIL import Image
+        >>> from io import BytesIO
+
+        >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+        >>> from transformers.image_utils import load_image
+
+        >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+        >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+        >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+        >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+        >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+        >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
+
+        >>> # Create inputs
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "video", "path": "path/to/video"},
+        ...             {"type": "text", "text": "What is happening in this video?"},
+        ...         ]
+        ...     }
+        ... ]
+
+        >>> inputs = processor.apply_chat_template(
+        ...     [messages], add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+        ... ).to(model.device)
+
+        >>> # Generate
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+        >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+        >>> print(generated_texts)
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_hidden_states=image_hidden_states,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
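+                # (Editor's sketch, hypothetical shapes) With logits of shape (B, T, V) and labels (B, T),
+                # `shift_attention_mask = attention_mask[:, -(T - 1):]` keeps the last T - 1 positions, so
+                # `shift_logits` gathers logits[..., :-1, :] at unmasked slots (the predictors) and
+                # `shift_labels` gathers labels[..., 1:] at the same slots (the targets): position t is
+                # trained to predict token t + 1, while padded positions drop out of the loss entirely.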
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return SmolVLMCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + pixel_values=None, + pixel_attention_mask=None, + image_hidden_states=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take + # precedence is moved to the model, we can remove this fn) + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + # but IDEFICS requires both ids and embeds to be present + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs["input_ids"] = input_ids + + if image_hidden_states is not None: + model_inputs["pixel_values"] = None + model_inputs["pixel_attention_mask"] = None + + return model_inputs + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + # Get the precomputed image_hidden_states + model_kwargs["image_hidden_states"] = outputs.image_hidden_states + return model_kwargs + + +__all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"] diff --git a/docs/transformers/src/transformers/models/smolvlm/modular_smolvlm.py b/docs/transformers/src/transformers/models/smolvlm/modular_smolvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..4745fe30dad29b00b233311a71da3edece691820 --- /dev/null +++ b/docs/transformers/src/transformers/models/smolvlm/modular_smolvlm.py @@ -0,0 +1,400 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# Written by Orr Zohar +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import DynamicCache +from ...utils import ( + logging, +) +from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig +from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor +from ..idefics3.modeling_idefics3 import ( + Idefics3BaseModelOutputWithPast, + Idefics3ForConditionalGeneration, + Idefics3Model, + Idefics3PreTrainedModel, + Idefics3VisionTransformer, +) + + +logger = logging.get_logger(__name__) + + +class SmolVLMVisionConfig(Idefics3VisionConfig): + r""" + This is the configuration class to store the configuration of a [`SmolVLMVisionModel`]. It is used to instantiate a + SmolVLM vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint + [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) used in SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + + ```python + >>> from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer + >>> from transformers.models.smolvlm.configuration_smolvlm import SmolVLMVisionConfig + + >>> # Initializing a SmolVLMVisionConfig with google/siglip-so400m-patch14-384 style configuration + >>> configuration = SmolVLMVisionConfig() + + >>> # Initializing a SmolVLMVisionTransformer (with random weights) from the google/siglip-so400m-patch14-384 style configuration + >>> model = SmolVLMVisionTransformer(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "smolvlm_vision" + pass + + +class SmolVLMPreTrainedModel(Idefics3PreTrainedModel): + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +class SmolVLMVisionTransformer(Idefics3VisionTransformer): + pass + + +class SmolVLMConfig(Idefics3Config): + r""" + This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a + SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM + [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should cache the key/value pairs of the attention mechanism. Only + relevant if `config.is_decoder=True`. + image_token_id (`int`, *optional*, defaults to 128257): + The id of the "image" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the token embeddings. + vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`): + Custom vision config or dict for the vision tower + text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`): + Custom text config or dict for the text model + scale_factor (`int`, *optional*, defaults to 2): + The scale factor for the image encoder. + pad_token_id (`int`, *optional*, defaults to 128002): + The id of the padding token. + + Example: + ```python + >>> from transformers import SmolVLMModel, SmolVLMConfig + >>> # Initializing configuration + >>> configuration = SmolVLMConfig() + >>> # Initializing a model from the configuration + >>> model = SmolVLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "smolvlm" + pass + + +class SmolVLMImageProcessor(Idefics3ImageProcessor): + pass + + +class SmolVLMBaseModelOutputWithPast(Idefics3BaseModelOutputWithPast): + pass + + +class SmolVLMModel(Idefics3Model): + """ + A subclass of Idefics3Model. 
We do *not* remove or block the call to inputs_merger + in forward. Instead, we override inputs_merger here with custom logic. + """ + + def inputs_merger( + self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor + ): + _, patch_size, _ = image_hidden_states.shape + + image_mask = input_ids == self.image_token_id + num_image_tokens = image_mask.sum(dim=1) + if not torch.all(num_image_tokens % patch_size == 0): + raise ValueError("At least one sample has tokens not divisible by patch_size.") + + blocks_per_sample = num_image_tokens // patch_size + + offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0) + block_offset = offsets[:-1] + row_cum = image_mask.cumsum(dim=-1) + chunk_idx = (row_cum - 1) // patch_size + local_idx = (row_cum - 1) % patch_size + block_idx = block_offset.unsqueeze(1) + chunk_idx + + image_embeds = torch.zeros_like(inputs_embeds) + image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :] + + merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) + return merged_embeds + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
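+                # (Editor's sketch, assumed toy shapes) For the padding-image removal below: with
+                # pixel_values of shape (batch=2, num_images=3, 3, H, W) flattened to (6, 3, H, W),
+                # an all-zero slice i satisfies
+                # (pixel_values[i] == 0.0).sum() == 3 * H * W == nb_values_per_image, so
+                # real_images_inds marks it False and it is dropped before the vision tower runs.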
+            )
+            use_cache = False
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        past_seen_tokens = 0
+        if use_cache:
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            past_seen_tokens = past_key_values.get_seq_length()
+
+        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
+            raise ValueError("When first calling the model, if inputs_embeds are passed, input_ids should not be None.")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
+
+        # START VISUAL INPUTS INTEGRATION
+        if pixel_values is not None and image_hidden_states is not None:
+            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+        elif pixel_values is not None:
+            batch_size, num_images, num_channels, height, width = pixel_values.shape
+            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+            # Remove padding images - padding images are full 0.
+            nb_values_per_image = pixel_values.shape[1:].numel()
+            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+
+            if not any(real_images_inds):
+                # no images, leave one empty image.
+                real_images_inds[0] = True
+
+            pixel_values = pixel_values[real_images_inds].contiguous()
+
+            # Handle the vision attention mask
+            if pixel_attention_mask is None:
+                pixel_attention_mask = torch.ones(
+                    size=[pixel_values.shape[i] for i in (0, 2, 3)],
+                    dtype=torch.bool,
+                    device=pixel_values.device,
+                )
+            else:
+                # Remove padding images from the mask
+                pixel_attention_mask = pixel_attention_mask.view(
+                    batch_size * num_images, *pixel_attention_mask.shape[2:]
+                )
+                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+            patch_size = self.config.vision_config.patch_size
+            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+            # Get sequence from the vision encoder
+            image_hidden_states = self.vision_model(
+                pixel_values=pixel_values,
+                patch_attention_mask=patch_attention_mask,
+            ).last_hidden_state
+
+            # Modality projection & resampling
+            image_hidden_states = self.connector(image_hidden_states)
+
+        elif image_hidden_states is not None:
+            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+        if inputs_embeds is not None and image_hidden_states is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            inputs_embeds = self.inputs_merger(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                image_hidden_states=image_hidden_states,
+            )
+
+        outputs = self.text_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        if not return_dict:
+            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
+
+        return SmolVLMBaseModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_hidden_states,
+        )
+
+
+class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
+    """
+    A subclass of Idefics3ForConditionalGeneration that uses SmolVLMModel
+    instead of the default Idefics3Model.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = SmolVLMModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.post_init()
+
+    def forward(self, **super_kwargs):
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`).
+                Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
+                computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import requests
+        >>> import torch
+        >>> from PIL import Image
+        >>> from io import BytesIO
+
+        >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+        >>> from transformers.image_utils import load_image
+
+        >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
+        >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+        >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
+        >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
+
+        >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+        >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
+
+        >>> # Create inputs
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "video", "path": "path/to/video"},
+        ...             {"type": "text", "text": "What is happening in this video?"},
+        ...         ]
+        ...     }
+        ... ]
+
+        >>> inputs = processor.apply_chat_template(
+        ...     [messages], add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+        ... ).to(model.device)
+
+        >>> # Generate
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
+        >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+        >>> print(generated_texts)
+        ```"""
+        return super().forward(**super_kwargs)
+
+
+__all__ = [
+    "SmolVLMVisionConfig",
+    "SmolVLMConfig",
+    "SmolVLMImageProcessor",
+    "SmolVLMForConditionalGeneration",
+    "SmolVLMPreTrainedModel",
+    "SmolVLMModel",
+    "SmolVLMVisionTransformer",
+]
diff --git a/docs/transformers/src/transformers/models/smolvlm/processing_smolvlm.py b/docs/transformers/src/transformers/models/smolvlm/processing_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2357a4cc1d052a4b235a8547b41058abc84eb9
--- /dev/null
+++ b/docs/transformers/src/transformers/models/smolvlm/processing_smolvlm.py
@@ -0,0 +1,476 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for SmolVLM.
+"""
+
+import copy
+from datetime import timedelta
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import (
+    ImageInput,
+    VideoInput,
+    load_video,
+    make_batched_videos,
+    make_nested_list_of_images,
+)
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, TextInput
+from ...utils import is_num2words_available, logging
+from .video_processing_smolvlm import (
+    DEFAULT_MEDIA_OUTTRO,
+    DEFAULT_VIDEO_INTRO,
+    FRAME_TIMESTAMP_MESSAGE,
+    smolvlm_sample_indices_fn,
+)
+
+
+if TYPE_CHECKING:
+    from ...tokenization_utils_base import PreTokenizedInput
+
+logger = logging.get_logger(__name__)
+
+
+if is_num2words_available():
+    from num2words import num2words
+else:
+    num2words = None
+
+
+def _prompt_split_image(
+    image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_image_token
+):
+    """Prompt with expanded image tokens for when the image is split into patches."""
+    text_split_images = ""
+    for n_h in range(image_rows):
+        for n_w in range(image_cols):
+            text_split_images += (
+                f"{fake_token_around_image}" + f"<row_{n_h + 1}_col_{n_w + 1}>" + f"{image_token}" * image_seq_len
+            )
+        text_split_images += "\n"
+
+    text_split_images += (
+        f"\n{fake_token_around_image}"
+        + f"{global_image_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+    return text_split_images
+
+
+def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, global_image_token):
+    """Prompt with expanded image tokens for a single image."""
+    return (
+        f"{fake_token_around_image}"
+        + f"{global_image_token}"
+        + f"{image_token}" * image_seq_len
+        + f"{fake_token_around_image}"
+    )
+
+
+def get_image_prompt_string(
+    image_rows, image_cols, image_seq_len, fake_token_around_image, image_token, global_image_token
+):
+    if image_rows == 0 and image_cols == 0:
+        return _prompt_single_image(
+            image_seq_len,
+            fake_token_around_image=fake_token_around_image,
+            image_token=image_token,
+            global_image_token=global_image_token,
+        )
+    return _prompt_split_image(
+        image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_image_token
+    )
+
+
+class SmolVLMImagesKwargs(ImagesKwargs, total=False):
+    return_row_col_info: Optional[bool]
+    max_image_size: Optional[Dict[str, int]]
+
+
+class SmolVLMProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: SmolVLMImagesKwargs
+
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "is_split_into_words": False,
+        },
+        "images_kwargs": {
+            "return_row_col_info": True,
+        },
+    }
+
+
+class SmolVLMProcessor(ProcessorMixin):
+    r"""
+    Constructs a SmolVLM processor which wraps a Llama tokenizer and SmolVLM image processor into a single processor.
+
+    [`SmolVLMProcessor`] offers all the functionalities of [`SmolVLMImageProcessor`] and [`SmolVLMTokenizerFast`]. See
+    the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
+
+    Args:
+        image_processor (`SmolVLMImageProcessor`):
+            An instance of [`SmolVLMImageProcessor`]. The image processor is a required input.
+        tokenizer (`PreTrainedTokenizerBase`, *optional*):
+            An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
+        image_seq_len (`int`, *optional*, defaults to 169):
+            The length of the image sequence i.e. the number of <image> tokens per image in the input.
+            This parameter is used to build the string from the input prompt and image tokens and should match the
+            value the model used. It is computed as: image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["image_seq_len", "chat_template"]
+    image_processor_class = "SmolVLMImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self, image_processor, tokenizer=None, image_seq_len: int = 169, chat_template: Optional[str] = None, **kwargs
+    ):
+        self.fake_image_token = getattr(tokenizer, "fake_image_token", "<fake_token_around_image>")
+        self.image_token = getattr(tokenizer, "image_token", "<image>")
+        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        self.end_of_utterance_token = getattr(tokenizer, "end_of_utterance_token", "<end_of_utterance>")
+        self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
+        self.image_seq_len = image_seq_len
+
+        self.video_size = image_processor.video_sampling["video_size"]
+        self.image_size = image_processor.size
+
+        self.do_image_splitting = image_processor.do_image_splitting
+        self.do_video_splitting = image_processor.video_sampling.get("do_image_splitting", False)
+
+        self.default_max_frames = image_processor.video_sampling["max_frames"]
+        self.default_fps = image_processor.video_sampling["fps"]
+        # Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits), optionally
+        # surrounded by newline characters
+        # self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
+
+        if not num2words:
+            raise ImportError(
+                "Package `num2words` is required to run SmolVLM processor. Install it with `pip install num2words`."
+            )
+
+        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
+
+    def process_vision(self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None):
+        if text is not None:
+            n_images_in_text = [sample.count(self.image_token) for sample in text]
+
+        n_images_in_images = [len(sublist) for sublist in images]
+        image_inputs = self.image_processor(
+            images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs["images_kwargs"]
+        )
+
+        if text is None:
+            return None, image_inputs
+
+        if n_images_in_images != n_images_in_text:
+            raise ValueError(
+                f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
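+                # (Editor's sketch) e.g. text = ["<image> describe"] with images = [[img_a, img_b]]
+                # gives n_images_in_text == [1] but n_images_in_images == [2], which raises here.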
+            )
+        image_rows = image_inputs.pop("rows", [[0] * len(text)])
+        image_cols = image_inputs.pop("cols", [[0] * len(text)])
+
+        prompt_strings = []
+        for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
+            # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
+            image_prompt_strings = []
+            for n_rows, n_cols in zip(sample_rows, sample_cols):
+                image_prompt_string = get_image_prompt_string(
+                    n_rows,
+                    n_cols,
+                    self.image_seq_len,
+                    image_token=self.image_token,
+                    fake_token_around_image=self.fake_image_token,
+                    global_image_token=self.global_image_token,
+                )
+                image_prompt_strings.append(image_prompt_string)
+
+            split_sample = sample.split(self.image_token)
+            if len(split_sample) == 0:
+                raise ValueError("The image token should be present in the text.")
+
+            # Place in the image prompt strings where the image tokens are
+            sample = split_sample[0]
+            for i, image_prompt_string in enumerate(image_prompt_strings):
+                sample += image_prompt_string + split_sample[i + 1]
+            prompt_strings.append(sample)
+
+        return prompt_strings, image_inputs
+
+    def __call__(
+        self,
+        images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
+        text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
+        audio=None,
+        videos: VideoInput = None,
+        **kwargs: Unpack[SmolVLMProcessorKwargs],
+    ) -> BatchEncoding:
+        """
+        Processes the input prompts and returns a BatchEncoding.
+
+        Example:
+
+        ```python
+        >>> import requests
+        >>> from transformers import SmolVLMProcessor
+        >>> from transformers.image_utils import load_image
+
+        >>> processor = SmolVLMProcessor.from_pretrained("HuggingFaceM4/SmolVLM2-256M-Video-Instruct")
+        >>> processor.image_processor.do_image_splitting = False  # Force as False to simplify the example
+
+        >>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+        >>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
+
+        >>> image1, image2 = load_image(url1), load_image(url2)
+        >>> images = [[image1], [image2]]
+
+        >>> text = [
+        ...     "<image>In this image, we see",
+        ...     "bla bla bla<image>",
+        ... ]
+        >>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
+        >>> input_ids = outputs.input_ids
+        >>> input_tokens = processor.tokenizer.batch_decode(input_ids)
+        >>> print(input_tokens)
+        ['<|begin_of_text|><fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image> In this image, we see', '<|reserved_special_token_0|><|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>bla bla bla<fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image>']
+        ```
+
+        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. If it is of type `List[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
+            text (`Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]`, *optional*):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+                Wherever an image token, `<image>`, is encountered it is expanded to
+                `<fake_token_around_image>` + `<global-img>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
+            return_tensors (`Union[str, TensorType]`, *optional*):
+                If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
+                information.
+        """
+        if text is None and images is None and videos is None:
+            raise ValueError("You must provide one of `text`, `images` or `videos`.")
+
+        if text is None and ((images is None) ^ (videos is not None)):
+            raise ValueError("You must specify exactly one of `images` or `videos`")
+
+        output_kwargs = self._merge_kwargs(
+            SmolVLMProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if text is not None:
+            if isinstance(text, str):
+                text = [text]
+            elif not (isinstance(text, list) and isinstance(text[0], str)):
+                raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+            n_images_in_text = sum([sample.count(self.image_token) for sample in text])
+            if n_images_in_text > 0 and (images is None and videos is None):
+                raise ValueError(f"We detected {n_images_in_text} <image> tokens in the text but no images/videos were passed")
+
+        inputs = {}
+        # Images and videos are mutually exclusive, so process one which is present
+        if images is not None:
+            images = make_nested_list_of_images(images)
+            text, vision_inputs = self.process_vision(
+                text,
+                images,
+                output_kwargs,
+                do_image_splitting=self.do_image_splitting,
+                image_processor_size=self.image_size,
+            )
+            inputs.update(vision_inputs)
+        elif videos is not None:
+            videos = make_batched_videos(videos)
+            text, vision_inputs = self.process_vision(
+                text,
+                videos,
+                output_kwargs,
+                do_image_splitting=self.do_image_splitting,
+                image_processor_size=self.video_size,
+            )
+            inputs.update(vision_inputs)
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+
+        if text is not None:
+            text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+            self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+            inputs.update(text_inputs)
+
+        return BatchFeature(inputs, tensor_type=return_tensors)
+
+    def _process_messages_for_chat_template(
+        self,
+        conversations: List[List[Dict[str, str]]],
+        batch_images: List[ImageInput],
+        batch_videos: List[VideoInput],
+        batch_video_metadata: List[List[Dict[str, any]]],
+        **chat_template_kwargs,
+    ):
+        """
+        Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
+        video models might want to specify in the prompt the duration of the video, or which frame indices were sampled
+        at which timestamps. This information cannot be accessed before the video is loaded.
+        For most models it is a no-op, and it must be overridden by model processors which require special processing.
+        Args:
+            conversations (`List[List[Dict[str, str]]]`):
+                The conversations to process. Always comes in batched format.
+            batch_images (`List[List[ImageInput]]`):
+                Batch of images that were loaded from url/path defined in the conversation. The images
+                are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
+                per batch.
+            batch_videos (`List[List[ImageInput]]`):
+                Batch of videos that were loaded from url/path defined in the conversation. The videos
+                are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
+                per batch.
+            batch_video_metadata (`List[List[Dict[str, any]]]`):
+                Batch of metadata returned from loading videos. That includes video fps, duration and total number of frames in the original video.
+                Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
+                per batch.
+        """
+        # We don't want to modify in-place the messages passed by user
+        # The user might want to add new turn on conv and continue generation
+        conversations = copy.deepcopy(conversations)
+        batch_num_frames, batch_timestamps, batch_durations = [], [], []
+        for metadata_list, video_list in zip(batch_video_metadata, batch_videos):
+            for metadata, video in zip(metadata_list, video_list):
+                duration_sec = getattr(metadata, "duration")
+                frames_idx = getattr(metadata, "frames_indices")
+                fps = getattr(metadata, "fps")
+
+                timestamps = []
+                for idx, frame_np in zip(frames_idx, video):
+                    sec = idx / fps
+                    mm = int(sec // 60)
+                    ss = int(sec % 60)
+                    timestamps.append(f"{mm:02d}:{ss:02d}")
+                batch_timestamps.append(timestamps)
+                batch_num_frames.append(len(video))
+                # Keep the duration per video so each video block below formats its own intro
+                batch_durations.append(duration_sec)
+
+        for conversation in conversations:
+            # For each message, scan content for {"type": "video"}
+            for msg in conversation:
+                if "content" not in msg:
+                    continue
+
+                new_content = []
+                for block in msg["content"]:
+                    if block.get("type") == "video":
+                        curr_timestamps = batch_timestamps.pop(0)
+                        curr_num_frames = batch_num_frames.pop(0)
+                        curr_duration_sec = batch_durations.pop(0)
+
+                        # 1) Build the video intro texts
+                        td = timedelta(seconds=int(curr_duration_sec))
+                        new_content.append(
+                            {
+                                "type": "text",
+                                "text": DEFAULT_VIDEO_INTRO.format(
+                                    frame_count=num2words(curr_num_frames), video_duration=str(td)
+                                ),
+                            }
+                        )
+
+                        # 2) Insert per-frame lines: "Frame from {timestamp}:", then an "image" block
+                        for ts in curr_timestamps:
+                            new_content.append({"type": "text", "text": FRAME_TIMESTAMP_MESSAGE.format(timestamp=ts)})
+                            new_content.append({"type": "image"})
+
+                        # 3) Optionally add an outro (e.g. "Now answer the question:")
+                        new_content.append({"type": "text", "text": DEFAULT_MEDIA_OUTTRO})
+                        # Do NOT add the original block => we skip it (since we've replaced it)
+                    else:
+                        # keep original block
+                        new_content.append(block)
+
+                # update the content
+                msg["content"] = new_content
+        return conversations
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs)
+        return batched_decode_output
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        decode_output = self.tokenizer.decode(*args, **kwargs)
+        return decode_output
+
+    @property
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
+
+    # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
+    def _load_video_for_model(
+        self,
+        video: Union[str, "VideoInput"],
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+        backend: str = "opencv",
+        skip_secs: float = 0.0,
+        **kwargs,
+    ) -> np.array:
+        """
+        Loads `video` to a numpy array.
+
+        Args:
+            video (`str` or `VideoInput`):
+                The video to convert to the numpy array format. Can be a link to video or local path.
+            num_frames (`int`, *optional*):
+                Number of frames to sample uniformly. If not passed, the whole video is loaded.
+            fps (`int`, *optional*):
+                Number of frames to sample per second. Should be passed only when `num_frames=None`.
+                If not specified and `num_frames==None`, all frames are sampled.
+            backend (`str`, *optional*, defaults to `"opencv"`):
+                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
+
+        Returns:
+            Tuple[`np.array`, Dict]: A tuple containing:
+                - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+                - Metadata dictionary.
+        """
+        max_frames = self.default_max_frames if num_frames is None else num_frames
+        target_fps = self.default_fps if fps is None else fps
+
+        def sample_indices_fn_func(metadata, **fn_kwargs):
+            return smolvlm_sample_indices_fn(
+                metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
+            )
+
+        video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
+        return video, metadata
+
+
+__all__ = ["SmolVLMProcessor"]
diff --git a/docs/transformers/src/transformers/models/smolvlm/video_processing_smolvlm.py b/docs/transformers/src/transformers/models/smolvlm/video_processing_smolvlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4adde6be578b8b9a67c6b0ad3e4d9c95e667bfee
--- /dev/null
+++ b/docs/transformers/src/transformers/models/smolvlm/video_processing_smolvlm.py
@@ -0,0 +1,90 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+
+# Make sure these are imported from your library
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+DEFAULT_SYSTEM_MESSAGE = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+DEFAULT_VIDEO_INTRO = (
+    "You are provided the following series of {frame_count} frames from a {video_duration} [H:MM:SS] video.\n"
+)
+DEFAULT_MEDIA_OUTTRO = "\n\n"
+FRAME_TIMESTAMP_MESSAGE = "\nFrame from {timestamp}:"
+
+
+def smolvlm_sample_indices_fn(metadata, max_frames, target_fps, skip_secs=0):
+    """
+    Example sampling function which:
+      - Uses `max_frames` (if provided) or calculates it from `fps` and metadata.
+      - Applies a basic center-skip if fewer frames than available, otherwise
+        optionally skips `skip_secs` from both the start and end.
+      - Uniformly samples the desired number of frames between the start and end indices.
+
+    Args:
+        metadata (`dict`):
+            Contains video metadata such as "n_frames" and "video_fps".
+        max_frames (`int`):
+            Maximum number of frames to sample.
+        target_fps (`int`):
+            Target frames to sample per second.
+        skip_secs (`float`, *optional*, defaults to 0):
+            Number of seconds to skip from the start and end if the video is long enough.
+
+    Returns:
+        numpy.ndarray:
+            An array of unique frame indices to sample.
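+
+    Example (editor's sketch, assuming a metadata object that exposes `total_num_frames`,
+    `fps` and `duration` as attributes): for a 60 s clip at 30 fps (1800 frames),
+    `target_fps=1` and `max_frames=32` give `estimated_frames = 60` and
+    `desired_frames = 32`; with `skip_secs=0` the 32 indices are spread uniformly
+    over `[0, 1799]`.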
+ """ + + total_num_frames = getattr(metadata, "total_num_frames", 0) + if total_num_frames <= 0: + raise ValueError(f"Invalid total_num_frames={total_num_frames} in metadata.") + + native_fps = getattr(metadata, "fps", 30.0) + duration_seconds = getattr(metadata, "duration", 0) + + if duration_seconds <= 0: + raise ValueError(f"Invalid duration_seconds={duration_seconds} in metadata.") + + # Step 1) Estimate how many frames we'd sample at `target_fps`, fallback if target_fps <= 0 + estimated_frames = int(round(target_fps * duration_seconds)) + + # Step 2) desired_frames + desired_frames = min(estimated_frames, max_frames) + if desired_frames < 1: + desired_frames = 1 + + # Step 3) center skip logic + start_idx = 0 + end_idx = total_num_frames - 1 + + if skip_secs > 0 and (duration_seconds - 2 * skip_secs) > (max_frames * target_fps): + start_idx = int(skip_secs * native_fps) + end_idx = int(total_num_frames - skip_secs * native_fps) + + start_idx = max(0, start_idx) + end_idx = min(end_idx, total_num_frames - 1) + if start_idx >= end_idx: + start_idx, end_idx = 0, total_num_frames - 1 + + indices = np.linspace(start_idx, end_idx, desired_frames, dtype=int) + indices = np.unique(indices) + + return indices diff --git a/docs/transformers/src/transformers/models/speech_encoder_decoder/__init__.py b/docs/transformers/src/transformers/models/speech_encoder_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e07844d45c2d0f3c5bfba69b330c3beede0eab1 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_encoder_decoder/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_speech_encoder_decoder import * + from .modeling_flax_speech_encoder_decoder import * + from .modeling_speech_encoder_decoder import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py b/docs/transformers/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..14dfab7eaa64e28f981536157c4af7d077acc835 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class SpeechEncoderDecoderConfig(PretrainedConfig):
+    r"""
+    [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
+    [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
+    arguments, defining the encoder and decoder configs.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        kwargs (*optional*):
+            Dictionary of keyword arguments. Notably:
+
+                - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the encoder config.
+                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the decoder config.
+
+    Examples:
+
+    ```python
+    >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
+
+    >>> # Initializing a Wav2Vec2 & BERT style configuration
+    >>> config_encoder = Wav2Vec2Config()
+    >>> config_decoder = BertConfig()
+
+    >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+    >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & google-bert/bert-base-uncased style configurations
+    >>> model = SpeechEncoderDecoderModel(config=config)
+
+    >>> # Accessing the model configuration
+    >>> config_encoder = model.config.encoder
+    >>> config_decoder = model.config.decoder
+    >>> # set decoder config to causal lm
+    >>> config_decoder.is_decoder = True
+    >>> config_decoder.add_cross_attention = True
+
+    >>> # Saving the model, including its configuration
+    >>> model.save_pretrained("my-model")
+
+    >>> # loading model and config from pretrained folder
+    >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
+    >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+    ```"""
+
+    model_type = "speech-encoder-decoder"
+    sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
+    has_no_defaults_at_init = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if "encoder" not in kwargs or "decoder" not in kwargs:
+            raise ValueError(
+                f"A configuration of type {self.model_type} cannot be instantiated because both `encoder` and"
+                f" `decoder` sub-configurations must be passed, but only {kwargs} was given"
+            )
+
+        encoder_config = kwargs.pop("encoder")
+        encoder_model_type = encoder_config.pop("model_type")
+        decoder_config = kwargs.pop("decoder")
+        decoder_model_type = decoder_config.pop("model_type")
+
+        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
+        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
+        self.is_encoder_decoder = True
+
+    @classmethod
+    def from_encoder_decoder_configs(
+        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+ ) -> PretrainedConfig: + r""" + Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model + configuration and decoder model configuration. + + Returns: + [`SpeechEncoderDecoderConfig`]: An instance of a configuration object + """ + logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) + + +__all__ = ["SpeechEncoderDecoderConfig"] diff --git a/docs/transformers/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py b/docs/transformers/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd57c84f544dc58031b0cc00b6e4f6498c997fd --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py @@ -0,0 +1,357 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Wav2Vec2 checkpoint.""" + +import argparse + +import fairseq +import torch +from torch import nn + +from transformers import ( + MBart50Tokenizer, + MBartConfig, + MBartForCausalLM, + SpeechEncoderDecoderConfig, + SpeechEncoderDecoderModel, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2Model, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", +} +TOP_LEVEL_KEYS = [ + "lm_head", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert hf_shape == value.shape, ( + f"Shape of hf {key + '.' 
+ weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights_wav2vec2(fairseq_model, hf_model): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.feature_extractor + adapter = hf_model.adapter + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + elif any(x in name for x in ["adaptor", "w2v_encoder.proj.", "w2v_proj_ln."]): + load_adapter(name, value, adapter, unused_weights) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + weight_type = "weight" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, ( + f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was" + " found." 
+            )
+            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
+            logger.info(f"Feat extract layer norm bias of layer {layer_id} was initialized from {full_name}.")
+        elif "weight" in name:
+            assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+                f"{full_name} has size {value.shape}, but"
+                f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+            )
+            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
+            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+    else:
+        unused_weights.append(full_name)
+
+
+def load_adapter(full_name, value, adapter, unused_weights):
+    name = full_name.split("adaptor.")[-1]
+    items = name.split(".")
+
+    if items[1].isdigit():
+        layer_id = int(items[1])
+    else:
+        layer_id = None
+
+    if "adaptor" not in full_name:
+        if "proj_ln" in full_name:
+            # has to be layer norm
+            if "bias" in name:
+                assert value.shape == adapter.proj_layer_norm.bias.data.shape, (
+                    f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.bias.data.shape} was found."
+                )
+                adapter.proj_layer_norm.bias.data = value
+                logger.info(f"Adapter proj layer norm bias was initialized from {full_name}.")
+            if "weight" in name:
+                assert value.shape == adapter.proj_layer_norm.weight.data.shape, (
+                    f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.weight.data.shape} was found."
+                )
+                adapter.proj_layer_norm.weight.data = value
+                logger.info(f"Adapter proj layer norm weight was initialized from {full_name}.")
+        else:
+            # has to be projection layer
+            if "bias" in name:
+                assert value.shape == adapter.proj.bias.data.shape, (
+                    f"{full_name} has size {value.shape}, but {adapter.proj.bias.data.shape} was found."
+                )
+                adapter.proj.bias.data = value
+                logger.info(f"Adapter proj layer bias was initialized from {full_name}.")
+            if "weight" in name:
+                assert value.shape == adapter.proj.weight.data.shape, (
+                    f"{full_name} has size {value.shape}, but {adapter.proj.weight.data.shape} was found."
+                )
+                adapter.proj.weight.data = value
+                logger.info(f"Adapter proj layer weight was initialized from {full_name}.")
+    elif isinstance(layer_id, int):
+        if "bias" in name:
+            assert value.shape == adapter.layers[layer_id].conv.bias.data.shape, (
+                f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.bias.data.shape} was found."
+            )
+            adapter.layers[layer_id].conv.bias.data = value
+            logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.")
+        elif "weight" in name:
+            assert value.shape == adapter.layers[layer_id].conv.weight.data.shape, (
+                f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.weight.data.shape} was found."
+            )
+            adapter.layers[layer_id].conv.weight.data = value
+            logger.info(f"Adapter layer {layer_id} weight was initialized from {full_name}.")
+    else:
+        unused_weights.append(full_name)
+
+
+def make_linear_from_emb(emb):
+    vocab_size, emb_size = emb.weight.shape
+    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
+    lin_layer.weight.data = emb.weight.data
+    return lin_layer
+
+
+@torch.no_grad()
+def convert_wav2vec2_checkpoint(
+    checkpoint_path,
+    pytorch_dump_folder_path,
+    dict_path,
+    config_yaml_path,
+    encoder_config_path,
+    decoder_config_path,
+    add_adapter,
+    adapter_kernel_size,
+    adapter_stride,
+    decoder_start_token_id,
+    encoder_output_dim,
+):
+    """
+    Copy/paste/tweak model's weights to transformers design.
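+
+    Example invocation (hypothetical local paths; the checkpoint, dict and config yaml come from the original
+    fairseq training setup):
+
+        python convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py \
+            --checkpoint_path ./checkpoint_best.pt \
+            --dict_path ./dict.mbart50.txt \
+            --config_yaml_path ./config.yaml \
+            --pytorch_dump_folder_path ./wav2vec2-mbart50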
+ """ + # load configs + encoder_config = Wav2Vec2Config.from_pretrained( + encoder_config_path, + add_adapter=True, + adapter_stride=adapter_stride, + adapter_kernel_size=adapter_kernel_size, + token_token=True, + output_hidden_size=encoder_output_dim, + ) + decoder_config = MBartConfig.from_pretrained(decoder_config_path) + + # load model + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], + arg_overrides={ + "config_yaml": config_yaml_path, + "data": "/".join(dict_path.split("/")[:-1]), + "w2v_path": checkpoint_path, + "load_pretrained_decoder_from": None, + }, + ) + model = model[0].eval() + + # load feature extractor + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_config_path, token_token=True) + + # set weights for wav2vec2 encoder + hf_encoder = Wav2Vec2Model(encoder_config) + + recursively_load_weights_wav2vec2(model.encoder, hf_encoder) + + # load decoder weights + hf_decoder = MBartForCausalLM(decoder_config) + missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False) + logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}") + logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}") + + hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder) + hf_wav2vec.config.tie_word_embeddings = False + + tokenizer = MBart50Tokenizer(dict_path) + tokenizer.save_pretrained(pytorch_dump_folder_path) + + config = hf_wav2vec.config.to_dict() + config["pad_token_id"] = tokenizer.pad_token_id + config["bos_token_id"] = tokenizer.bos_token_id + config["eos_token_id"] = tokenizer.eos_token_id + config["tokenizer_class"] = "mbart50" + config["feature_extractor_type"] = "wav2vec2" + + config["decoder_start_token_id"] = tokenizer.eos_token_id + config["forced_bos_token_id"] = 250004 + config["forced_eos_token_id"] = tokenizer.eos_token_id + + hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config) + + hf_wav2vec.save_pretrained(pytorch_dump_folder_path) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument("--config_yaml_path", default=None, type=str, help="Path to yaml file of fine-tuned model") + parser.add_argument( + "--encoder_config_path", + default="facebook/wav2vec2-xls-r-1b", + type=str, + help="Path to hf encoder wav2vec2 checkpoint config", + ) + parser.add_argument( + "--decoder_config_path", + default="facebook/mbart-large-50-one-to-many-mmt", + type=str, + help="Path to hf decoder checkpoint config", + ) + parser.add_argument("--add_adapter", default=True, type=bool, help="whethere to add model adapter layers") + parser.add_argument("--adapter_stride", default=2, type=int, help="stride of adapter layers") + parser.add_argument("--adapter_kernel_size", default=3, type=int, help="kernel size of adapter layers") + parser.add_argument("--encoder_output_dim", default=1024, type=int, help="encoder output dim") + parser.add_argument("--start_token_id", default=250004, type=int, help="`decoder_start_token_id` of model config") + + args = 
parser.parse_args() + convert_wav2vec2_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.dict_path, + args.config_yaml_path, + encoder_config_path=args.encoder_config_path, + decoder_config_path=args.decoder_config_path, + add_adapter=args.add_adapter, + adapter_kernel_size=args.adapter_kernel_size, + adapter_stride=args.adapter_stride, + decoder_start_token_id=args.start_token_id, + encoder_output_dim=args.encoder_output_dim, + ) diff --git a/docs/transformers/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py b/docs/transformers/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..377288982087bacefac6ac35aa3c2cbb126c3388 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py @@ -0,0 +1,316 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Wav2Vec2 checkpoint.""" + +import argparse +import json +import os + +import fairseq +import torch +from torch import nn + +from transformers import ( + Speech2Text2Config, + Speech2Text2ForCausalLM, + Speech2Text2Tokenizer, + SpeechEncoderDecoderConfig, + SpeechEncoderDecoderModel, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2Model, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", +} +TOP_LEVEL_KEYS = [ + "lm_head", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert hf_shape == value.shape, ( + f"Shape of hf {key + '.' 
+ weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights_wav2vec2(fairseq_model, hf_model): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.feature_extractor + + # if encoder has different dim to decoder -> use proj_weight + proj_weight = None + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + elif name.split(".")[0] == "proj": + proj_weight = fairseq_model.proj + is_used = True + else: + for key, mapped_key in MAPPING.items(): + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + weight_type = "weight" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + return proj_weight + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, ( + f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was" + " found." 
+ ) + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def create_vocab_dict(dict_path): + with open(dict_path, "r", encoding="utf-8") as f: + lines = f.readlines() + words = [line.split(" ")[0] for line in lines] + + num_words = len(words) + + vocab_dict = { + "": 0, + "": 1, + "": 2, + "": 3, + } + + vocab_dict.update(dict(zip(words, range(4, num_words + 4)))) + return vocab_dict + + +@torch.no_grad() +def convert_wav2vec2_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + dict_path, + encoder_config_path, + decoder_config_path, + vocab_size, + num_decoder_layers, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + encoder_config = Wav2Vec2Config.from_pretrained(encoder_config_path) + decoder_config = Speech2Text2Config.from_pretrained( + decoder_config_path, vocab_size=vocab_size, decoder_layers=num_decoder_layers, do_stable_layer_norm=True + ) + + feature_extractor = Wav2Vec2FeatureExtractor( + feature_size=1, + sampling_rate=16000, + padding_value=0, + do_normalize=True, + return_attention_mask=True, + ) + + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])} + ) + model = model[0].eval() + + # set weights for wav2vec2 encoder + hf_encoder = Wav2Vec2Model(encoder_config) + projection_layer = recursively_load_weights_wav2vec2(model.encoder, hf_encoder) + + hf_decoder = Speech2Text2ForCausalLM(decoder_config) + missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False) + + # set output linear layer + unexpected_keys.remove("embed_out") + hf_decoder.lm_head.weight = nn.Parameter(model.decoder.embed_out.detach()) + + # layer norm is init to identity matrix so leaving it is fine + logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}") + logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}") + + hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder) + hf_wav2vec.config.tie_word_embeddings = False + + # add projection layer + hf_wav2vec.enc_to_dec_proj.weight = nn.Parameter(projection_layer.weight) + hf_wav2vec.enc_to_dec_proj.bias = nn.Parameter(projection_layer.bias) + + vocab_dict = create_vocab_dict(dict_path) + + with open(os.path.join(pytorch_dump_folder_path, "vocab.json"), "w") as fp: + json.dump(vocab_dict, fp) + + tokenizer = Speech2Text2Tokenizer(os.path.join(pytorch_dump_folder_path, "vocab.json")) + tokenizer.save_pretrained(pytorch_dump_folder_path) + + config = hf_wav2vec.config.to_dict() + config["pad_token_id"] = tokenizer.pad_token_id + config["bos_token_id"] = tokenizer.bos_token_id + 
config["eos_token_id"] = tokenizer.eos_token_id + config["tokenizer_class"] = "speech_to_text_2" + config["feature_extractor_type"] = "wav2vec2" + + hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config) + + hf_wav2vec.save_pretrained(pytorch_dump_folder_path) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument( + "--encoder_config_path", + default="facebook/wav2vec2-large-lv60", + type=str, + help="Path to hf encoder wav2vec2 checkpoint config", + ) + parser.add_argument( + "--decoder_config_path", + default="facebook/s2t-small-mustc-en-fr-st", + type=str, + help="Path to hf decoder s2t checkpoint config", + ) + parser.add_argument("--vocab_size", default=10224, type=int, help="Vocab size of decoder") + parser.add_argument("--num_decoder_layers", default=7, type=int, help="Number of decoder layers") + + args = parser.parse_args() + convert_wav2vec2_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.dict_path, + encoder_config_path=args.encoder_config_path, + decoder_config_path=args.decoder_config_path, + vocab_size=args.vocab_size, + num_decoder_layers=args.num_decoder_layers, + ) diff --git a/docs/transformers/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/docs/transformers/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fd837146d5ca7d80a2b7bbd3a86369a36b28a68a --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -0,0 +1,929 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Classes to support Flax Speech-Encoder-Decoder architectures""" + +import os +from typing import Optional, Tuple, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput +from ...modeling_flax_utils import FlaxPreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM +from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" + +SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech + autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is + loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via + [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder + and should be fine-tuned on a downstream generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech + Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech + translation yields a significant performance improvement. + + After such an Speech-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other + models (see the examples for more information). + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Parameters: + config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. 
+ + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): + Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac` + or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile + library (`pip install soundfile`). To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or + [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type + `torch.FloatTensor`. + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. +""" + +SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" + Args: + inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): + Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* + or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile + library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or + [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type + *torch.FloatTensor*. + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. +""" + +SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a + plain tuple. 
+""" + + +class FlaxSpeechEncoderDecoderModule(nn.Module): + config: SpeechEncoderDecoderConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + encoder_config = self.config.encoder + decoder_config = self.config.decoder + + # Copied from `modeling_hybrid_clip.py` with modifications. + from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING + + encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class + decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class + + self.encoder = encoder_module(encoder_config, dtype=self.dtype) + self.decoder = decoder_module(decoder_config, dtype=self.dtype) + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Dense( + self.decoder.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), + dtype=self.dtype, + ) + else: + self.enc_to_dec_proj = None + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.encoder.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride) + + return input_lengths + + def _get_encoder_module(self): + return self.encoder + + def _get_projection_module(self): + return self.enc_to_dec_proj + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + inputs, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + encoder_outputs=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + freeze_feature_encoder: bool = False, + ): + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + freeze_feature_encoder=freeze_feature_encoder, + ) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if self.enc_to_dec_proj is not None: + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # compute correct encoder attention mask + if attention_mask is not None: + encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( + encoder_hidden_states.shape[1], attention_mask + ) + else: + encoder_attention_mask = None + + # flax script modeling_flax_wav2vec2.py + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) +class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): + r""" + [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture + with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one + as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the + encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. + """ + + config_class = SpeechEncoderDecoderConfig + base_model_prefix: str = "speech_encoder_decoder" + module_class = FlaxSpeechEncoderDecoderModule + + def __init__( + self, + config: SpeechEncoderDecoderConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if not _do_init: + raise ValueError( + "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." + ) + + if config.decoder.cross_attention_hidden_size is not None: + # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer) + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." 
+ ) + + # make sure input & output embeddings are not tied + config.tie_word_embeddings = False + module = self.module_class(config=config, dtype=dtype, **kwargs) + + if input_shape is None: + # speech encoders almost always downsample the sequence length dimension + encoder_input_length = 1024 + decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length) + input_shape = ((1, encoder_input_length), (1, decoder_input_length)) + + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + encoder_input_shape, decoder_input_shape = input_shape + + # init input DeviceArrays + inputs = jnp.zeros(encoder_input_shape, dtype="f4") + attention_mask = jnp.ones_like(inputs, dtype="i4") + decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = inputs.shape + + decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape + if not decoder_batch_size == batch_size: + raise ValueError( + f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder" + f" and {decoder_batch_size} for decoder." + ) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + inputs, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
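+
+        Example (a minimal sketch, assuming `model` is an instantiated `FlaxSpeechEncoderDecoderModel` and the
+        input shapes are illustrative):
+
+        ```python
+        >>> import jax.numpy as jnp
+
+        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
+        >>> encoder_outputs = model.encode(inputs)
+        >>> # allocate a cache large enough to decode up to 16 tokens auto-regressively
+        >>> past_key_values = model.init_cache(batch_size=2, max_length=16, encoder_outputs=encoder_outputs)
+        ```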
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) + + @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) + def encode( + self, + inputs: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + freeze_feature_encoder: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxSpeechEncoderDecoderModel + + >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" + ... 
) + + >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) + >>> encoder_outputs = model.encode(inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(inputs, dtype="i4") + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, inputs, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(inputs, attention_mask, **kwargs) + + outputs = self.module.apply( + {"params": params or self.params}, + inputs=jnp.array(inputs, dtype="f4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + freeze_feature_encoder=freeze_feature_encoder, + rngs=rngs, + method=_encoder_forward, + ) + + if return_dict: + outputs = FlaxBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return outputs + + @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxSpeechEncoderDecoderModel + >>> import jax.numpy as jnp + + >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" + ... 
) + + >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) + >>> encoder_outputs = model.encode(inputs) + + >>> decoder_start_token_id = model.config.decoder.bos_token_id + >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + params = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBartAttention module + if past_key_values: + params["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward( + module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs + ): + projection_module = module._get_projection_module() + decoder_module = module._get_decoder_module() + + # optionally project encoder_hidden_states + if projection_module is not None: + encoder_hidden_states = projection_module(encoder_hidden_states) + + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + **kwargs, + ) + + outputs = self.module.apply( + params, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def __call__( + self, + 
+        inputs: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        freeze_feature_encoder: bool = False,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import FlaxSpeechEncoderDecoderModel, AutoTokenizer
+
+        >>> # load a fine-tuned wav2vec2-2-bart model
+        >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
+        >>> # load output tokenizer
+        >>> tokenizer_output = AutoTokenizer.from_pretrained("facebook/bart-large")
+
+        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
+
+        >>> # use bart's special bos, pad and eos tokens
+        >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
+        >>> model.config.pad_token_id = model.decoder.config.pad_token_id
+        >>> model.config.eos_token_id = model.decoder.config.eos_token_id
+
+        >>> outputs = model.generate(inputs)
+        ```
+        """
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # prepare encoder inputs
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(inputs, dtype="i4")
+
+        # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must"
+                " be specified as an input argument."
+            )
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        if decoder_position_ids is None:
+            batch_size, sequence_length = decoder_input_ids.shape
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+        return self.module.apply(
+            {"params": params or self.params},
+            inputs=jnp.array(inputs, dtype="f4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            freeze_feature_encoder=freeze_feature_encoder,
+            rngs=rngs,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyway.
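+        # (E.g. for a 3-token prompt decoded up to max_length=8, positions 3..7 are initially in the future,
+        # and the causal mask already prevents earlier tokens from attending to them.)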
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+            )
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": decoder_position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+    @classmethod
+    def from_encoder_decoder_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        *model_args,
+        **kwargs,
+    ) -> FlaxPreTrainedModel:
+        r"""
+        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+        checkpoints.
+
+        Params:
+            encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
+                Information necessary to initiate the encoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
+                Information necessary to initiate the decoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initiate the model
+                (e.g., `output_attentions=True`).
+
+                    - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                    - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                    - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import FlaxSpeechEncoderDecoderModel
+
+        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
+        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
+        ... )
+    @classmethod
+    def from_encoder_decoder_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        *model_args,
+        **kwargs,
+    ) -> FlaxPreTrainedModel:
+        r"""
+        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+        checkpoints.
+
+        Params:
+            encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
+                Information necessary to initiate the encoder. Can be either:
+
+                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing model weights saved using
+                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
+                Information necessary to initiate the decoder. Can be either:
+
+                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing model weights saved using
+                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initiate the model
+                (e.g., `output_attentions=True`).
+
+                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import FlaxSpeechEncoderDecoderModel
+
+        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
+        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
+        ... )
+        >>> # saving model after fine-tuning
+        >>> model.save_pretrained("./wav2vec2-2-bart-large")
+        >>> # load fine-tuned model
+        >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
+        ```"""
+
+        kwargs_encoder = {
+            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # remove encoder, decoder kwargs from kwargs
+        for key in kwargs_encoder.keys():
+            del kwargs["encoder_" + key]
+        for key in kwargs_decoder.keys():
+            del kwargs["decoder_" + key]
+
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs_encoder.pop("model", None)
+        if encoder is None:
+            if encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_encoder:
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
+                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                    logger.info(
+                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
+                        "from a decoder model. Cross-attention and causal mask are disabled."
+                    )
+                    encoder_config.is_decoder = False
+                    encoder_config.add_cross_attention = False
+
+                kwargs_encoder["config"] = encoder_config
+
+            encoder = FlaxAutoModel.from_pretrained(
+                encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
+            )
+
+        decoder = kwargs_decoder.pop("model", None)
+        if decoder is None:
+            if decoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_decoder:
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+                    logger.info(
+                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+                    )
+                    decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True
+
+                kwargs_decoder["config"] = decoder_config
+
+            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+                logger.warning(
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + dtype = kwargs.pop("dtype", jnp.float32) + config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # make sure input & output word embeddings are not tied + config.tie_word_embeddings = False + + # init model + model = cls(config, dtype=dtype) + model.params["encoder"] = encoder.params + model.params["decoder"] = decoder.params + + return model + + +__all__ = ["FlaxSpeechEncoderDecoderModel"] diff --git a/docs/transformers/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/docs/transformers/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..425c3f7d5b35e9fbd362386f7b09a85531595299 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -0,0 +1,587 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Speech-Encoder-Text-Decoder architectures""" + +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM +from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" + +SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech + autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is + loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via + [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder + and should be fine-tuned on a downstream generative task, like summarization. 
+
+    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+    Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
+
+    Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
+    Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
+    translation yields a significant performance improvement.
+
+    After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
+    model (see the examples for more information).
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.).
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+    Args:
+        inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
+            Float values of input raw speech waveform or speech features. Values can be obtained by loading a `.flac`
+            or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile
+            library (`pip install soundfile`). To prepare the array into `inputs`, either the [`Wav2Vec2Processor`] or
+            [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
+            `torch.FloatTensor`.
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
+            right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+ encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) is a tensor + of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the + decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install + soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding + and conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details. + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*): + Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained + by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* + via the soundfile library (`pip install soundfile`). 
To prepare the array into `input_features`, the
+            [`Speech2TextFeatureExtractor`] should be used for extracting the fbank features, padding and conversion
+            into a tensor of type `torch.FloatTensor`. See [`~Speech2TextFeatureExtractor.__call__`] for details.
+        return_dict (`bool`, *optional*):
+            If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple.
+        kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
+
+            - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
+            - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function.
+"""
+
+
+# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
+class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
+    r"""
+    [`SpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
+    one of the base model classes of the library as encoder and another one as decoder when created with the
+    [`~AutoModel.from_pretrained`] class method for the encoder and the
+    [`~AutoModelForCausalLM.from_pretrained`] class method for the decoder.
+    """
+
+    config_class = SpeechEncoderDecoderConfig
+    base_model_prefix = "speech_encoder_decoder"
+    main_input_name = "inputs"
+    supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+    ):
+        if config is None and (encoder is None or decoder is None):
+            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+        if config is None:
+            config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+        else:
+            if not isinstance(config, self.config_class):
+                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+        if config.decoder.cross_attention_hidden_size is not None:
+            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+                raise ValueError(
+                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+                    " `config.encoder.hidden_size`."
+                )
+
+        # initialize with config
+        # make sure input & output embeddings are not tied
+        config.tie_word_embeddings = False
+        super().__init__(config)
+
+        if encoder is None:
+            encoder = AutoModel.from_config(config.encoder)
+
+        if decoder is None:
+            decoder = AutoModelForCausalLM.from_config(config.decoder)
+
+        self.encoder = encoder
+        self.decoder = decoder
+
+        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+            logger.warning(
+                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+                f" {self.config.encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                f" {self.config.decoder}"
+            )
+
+        # make sure that the individual model's config refers to the shared config
+        # so that the updates to the config will be synced
+        self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
+        self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+        self.encoder.config = self.config.encoder
+        self.decoder.config = self.config.decoder
+
+        # get encoder output hidden size
+        self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
+        if (
+            self.encoder_output_dim != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            # encoder outputs might need to be projected to different dimension for decoder
+            self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+        if self.encoder.get_output_embeddings() is not None:
+            raise ValueError(
+                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
+            )
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def get_input_embeddings(self):
+        return self.decoder.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.decoder.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        return self.decoder.set_output_embeddings(new_embeddings)
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder of the speech encoder so
+        that its parameters will not be updated during training.
+        """
+        self.encoder.freeze_feature_encoder()
+
+    @classmethod
+    def from_encoder_decoder_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path: Optional[str] = None,
+        decoder_pretrained_model_name_or_path: Optional[str] = None,
+        *model_args,
+        **kwargs,
+    ) -> PreTrainedModel:
+        r"""
+        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+        checkpoints.
+
+        <Tip>
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To
+        train the model, you need to first set it back in training mode with `model.train()`.
+
+        </Tip>
+
+        Params:
+            encoder_pretrained_model_name_or_path (`str`, *optional*):
+                Information necessary to initiate the encoder. Can be either:
+
+                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing model weights saved using
+                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
+                  this case, `from_tf` should be set to `True` and a configuration object should be provided as
+                  `config` argument. This loading path is slower than converting the TensorFlow checkpoint into a
+                  PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+            decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+                Information necessary to initiate the decoder. Can be either:
+
+                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing model weights saved using
+                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
+                  this case, `from_tf` should be set to `True` and a configuration object should be provided as
+                  `config` argument. This loading path is slower than converting the TensorFlow checkpoint into a
+                  PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initiate the model
+                (e.g., `output_attentions=True`).
+
+                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import SpeechEncoderDecoderModel
+
+        >>> # initialize a wav2vec2bert from a pretrained Wav2Vec2 and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
+        >>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+        ...     "facebook/wav2vec2-base-960h", "google-bert/bert-base-uncased"
+        ... )
+        >>> # saving model after fine-tuning
+        >>> model.save_pretrained("./wav2vec2bert")
+        >>> # load fine-tuned model
+        >>> model = SpeechEncoderDecoderModel.from_pretrained("./wav2vec2bert")
+        ```"""
+
+        kwargs_encoder = {
+            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # remove encoder, decoder kwargs from kwargs
+        for key in kwargs_encoder.keys():
+            del kwargs["encoder_" + key]
+        for key in kwargs_decoder.keys():
+            del kwargs["decoder_" + key]
+
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs_encoder.pop("model", None)
+        if encoder is None:
+            if encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+ ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + return cls(encoder=encoder, decoder=decoder, config=config) + + @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + input_values: Optional[torch.FloatTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import SpeechEncoderDecoderModel, AutoProcessor + >>> from datasets import load_dataset + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15") + >>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values + >>> # Inference: Translate English speech to German + >>> generated = model.generate(input_values) + >>> decoded = processor.batch_decode(generated, skip_special_tokens=True)[0] + >>> decoded + 'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.' 
+ + >>> # Training: Train model on English transcription + >>> labels = processor(text=ds[0]["text"], return_tensors="pt").input_ids + + >>> loss = model(input_values, labels=labels).loss + >>> loss.backward() + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + if "num_items_in_batch" in kwargs_encoder: + kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None) + + if encoder_outputs is None: + if inputs is None: + if input_values is not None and input_features is not None: + raise ValueError("You cannot specify both input_values and input_features at the same time") + elif input_values is not None: + inputs = input_values + elif input_features is not None: + inputs = input_features + else: + raise ValueError("You have to specify either input_values or input_features") + + encoder_outputs = self.encoder( + inputs, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder_output_dim != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # compute correct encoder attention mask + if attention_mask is not None: + encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( + encoder_hidden_states.shape[1], attention_mask + ) + else: + encoder_attention_mask = None + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + **kwargs_decoder, + ) + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[0] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def 
prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the SpeechEncoderDecoderModel directly is not supported. Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past_key_values, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past_key_values, beam_idx) + + +__all__ = ["SpeechEncoderDecoderModel"] diff --git a/docs/transformers/src/transformers/models/speech_to_text/__init__.py b/docs/transformers/src/transformers/models/speech_to_text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec094769d4aea356895c219ee6a9b64f347890b8 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_speech_to_text import * + from .feature_extraction_speech_to_text import * + from .modeling_speech_to_text import * + from .modeling_tf_speech_to_text import * + from .processing_speech_to_text import * + from .tokenization_speech_to_text import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/docs/transformers/src/transformers/models/speech_to_text/configuration_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..fef4069e4e4170421f3b1e634f5c63632d7a74c7 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/configuration_speech_to_text.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
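Before moving on to the Speech2Text configuration, the label shifting performed by `shift_tokens_right` and surfaced through `prepare_decoder_input_ids_from_labels` above can be made concrete with a tiny worked example (the token ids are arbitrary):

```python
import torch

# toy labels with -100 marking ignored/padded positions, as produced by common collators
labels = torch.tensor([[250, 251, 252, -100, -100]])
pad_token_id = 1
decoder_start_token_id = 2

# reproduce shift_tokens_right by hand: prepend the start token, drop the last label,
# and replace every -100 with the pad token
shifted = labels.new_zeros(labels.shape)
shifted[:, 1:] = labels[:, :-1].clone()
shifted[:, 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)

print(shifted)  # tensor([[  2, 250, 251, 252,   1]])
```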
+"""Speech2Text model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Speech2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Speech2TextModel`]. It is used to instantiate a + Speech2Text model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Speech2Text + [facebook/s2t-small-librispeech-asr](https://huggingface.co/facebook/s2t-small-librispeech-asr) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 10000): + Vocabulary size of the Speech2Text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Speech2TextModel`] + encoder_layers (`int`, *optional*, defaults to 12): + Number of encoder layers. + encoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. + encoder_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_ffn_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + decoder_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer decoder. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for + more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for + more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model is set up as an encoder-decoder architecture for sequence-to-sequence tasks. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + d_model (`int`, *optional*, defaults to 256): + Dimensionality of the layers and the pooler layer. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + decoder_start_token_id (`int`, *optional*, defaults to 2): + The initial token ID of the decoder when decoding sequences. 
+ scale_embedding (`bool`, *optional*, defaults to `True`): + Whether the embeddings are scaled by the square root of `d_model`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning-of-sequence token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end-of-sequence token. + max_source_positions (`int`, *optional*, defaults to 6000): + The maximum sequence length of log-mel filter-bank features that this model might ever be used with. + max_target_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + num_conv_layers (`int`, *optional*, defaults to 2): + Number of 1D convolutional layers in the conv module. + conv_kernel_sizes (`Tuple[int]`, *optional*, defaults to `(5, 5)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the conv module. The length + of `conv_kernel_sizes` has to match `num_conv_layers`. + conv_channels (`int`, *optional*, defaults to 1024): + An integer defining the number of output channels of each convolution layers except the final one in the + conv module. + input_feat_per_channel (`int`, *optional*, defaults to 80): + An integer specifying the size of feature vector. This is also the dimensions of log-mel filter-bank + features. + input_channels (`int`, *optional*, defaults to 1): + An integer specifying number of input channels of the input feature vector. + + Example: + + ```python + >>> from transformers import Speech2TextConfig, Speech2TextModel + + >>> # Initializing a Speech2Text s2t_transformer_s style configuration + >>> configuration = Speech2TextConfig() + + >>> # Initializing a model (with random weights) from the s2t_transformer_s style configuration + >>> model = Speech2TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "speech_to_text" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=10000, + encoder_layers=12, + encoder_ffn_dim=2048, + encoder_attention_heads=4, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=4, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + scale_embedding=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + max_source_positions=6000, + max_target_positions=1024, + num_conv_layers=2, + conv_kernel_sizes=(5, 5), + conv_channels=1024, + input_feat_per_channel=80, + input_channels=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop 
+ self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.num_conv_layers = num_conv_layers + self.conv_kernel_sizes = list(conv_kernel_sizes) + self.conv_channels = conv_channels + self.input_feat_per_channel = input_feat_per_channel + self.input_channels = input_channels + + if len(self.conv_kernel_sizes) != self.num_conv_layers: + raise ValueError( + "Configuration for convolutional module is incorrect. " + "It is required that `len(config.conv_kernel_sizes)` == `config.num_conv_layers` " + f"but is `len(config.conv_kernel_sizes) = {len(self.conv_kernel_sizes)}`, " + f"`config.num_conv_layers = {self.num_conv_layers}`." + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +__all__ = ["Speech2TextConfig"] diff --git a/docs/transformers/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py b/docs/transformers/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py new file mode 100644 index 0000000000000000000000000000000000000000..9286fae776fd857e4f45b77e7c24f55f5e05ac3b --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py @@ -0,0 +1,121 @@ +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
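As a quick illustration of the constraint enforced in `Speech2TextConfig.__init__` above, where each convolutional layer must have a matching kernel size (the values below are arbitrary):

```python
from transformers import Speech2TextConfig

# consistent: two conv layers, two kernel sizes
config = Speech2TextConfig(num_conv_layers=2, conv_kernel_sizes=(5, 5))
print(config.conv_kernel_sizes)  # [5, 5]

# inconsistent lengths raise the ValueError defined in __init__
try:
    Speech2TextConfig(num_conv_layers=3, conv_kernel_sizes=(5, 5))
except ValueError as err:
    print(err)
```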
+ +import argparse + +import torch +from torch import nn + +from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration + + +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "decoder.output_projection.weight", + "_float_tensor", + "encoder.embed_positions._float_tensor", + "decoder.embed_positions._float_tensor", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_keys(s_dict): + keys = list(s_dict.keys()) + for key in keys: + if "transformer_layers" in key: + s_dict[key.replace("transformer_layers", "layers")] = s_dict.pop(key) + elif "subsample" in key: + s_dict[key.replace("subsample", "conv")] = s_dict.pop(key) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_path): + m2m_100 = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + args = m2m_100["args"] + state_dict = m2m_100["model"] + lm_head_weights = state_dict["decoder.output_projection.weight"] + + remove_ignore_keys_(state_dict) + rename_keys(state_dict) + + vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0] + + tie_embeds = args.share_decoder_input_output_embed + + conv_kernel_sizes = [int(i) for i in args.conv_kernel_sizes.split(",")] + config = Speech2TextConfig( + vocab_size=vocab_size, + max_source_positions=args.max_source_positions, + max_target_positions=args.max_target_positions, + encoder_layers=args.encoder_layers, + decoder_layers=args.decoder_layers, + encoder_attention_heads=args.encoder_attention_heads, + decoder_attention_heads=args.decoder_attention_heads, + encoder_ffn_dim=args.encoder_ffn_embed_dim, + decoder_ffn_dim=args.decoder_ffn_embed_dim, + d_model=args.encoder_embed_dim, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_function="relu", + num_conv_layers=len(conv_kernel_sizes), + conv_channels=args.conv_channels, + conv_kernel_sizes=conv_kernel_sizes, + input_feat_per_channel=args.input_feat_per_channel, + input_channels=args.input_channels, + tie_word_embeddings=tie_embeds, + num_beams=5, + max_length=200, + use_cache=True, + decoder_start_token_id=2, + early_stopping=True, + ) + + model = Speech2TextForConditionalGeneration(config) + missing, unexpected = model.model.load_state_dict(state_dict, strict=False) + if len(missing) > 0 and not set(missing) <= { + "encoder.embed_positions.weights", + "decoder.embed_positions.weights", + }: + raise ValueError( + "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing," + f" but all the following weights are missing {missing}" + ) + + if tie_embeds: + model.lm_head = make_linear_from_emb(model.model.decoder.embed_tokens) + else: + model.lm_head.weight.data = lm_head_weights + + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--fairseq_path", type=str, help="Path to the fairseq model (.pt) file.") + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_fairseq_s2t_checkpoint_to_tfms(args.fairseq_path, 
args.pytorch_dump_folder_path)
diff --git a/docs/transformers/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/docs/transformers/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..3abbfacb8de0a584360952e2674ede12ffc2750a
--- /dev/null
+++ b/docs/transformers/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -0,0 +1,315 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for Speech2Text
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+
+from ...audio_utils import mel_filter_bank, spectrogram, window_function
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, is_speech_available, logging
+
+
+if is_speech_available():
+    import torch
+    import torchaudio.compliance.kaldi as ta_kaldi
+
+logger = logging.get_logger(__name__)
+
+
+class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs a Speech2Text feature extractor.
+
+    This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users
+    should refer to this superclass for more information regarding those methods.
+
+    This class extracts mel-filter bank features from raw speech using TorchAudio if installed, or using numpy
+    otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features.
+
+    Args:
+        feature_size (`int`, *optional*, defaults to 80):
+            The feature dimension of the extracted features.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+        num_mel_bins (`int`, *optional*, defaults to 80):
+            Number of Mel-frequency bins.
+        padding_value (`float`, *optional*, defaults to 0.0):
+            The value that is used to fill the padding vectors.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering. In other words, adds a small Gaussian noise to each frame.
+            E.g. use 4.0 to add dithering with a normal distribution centered
+            around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform).
+            The value 0.0 means no dithering.
+            Dithering has a similar effect to `mel_floor`. It reduces the high log_mel_fbank
+            values for signals with hard-zero sections, when VAD cutoff is present in the signal.
+        do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
+        normalize_means (`bool`, *optional*, defaults to `True`):
+            Whether or not to zero-mean normalize the extracted features.
+        normalize_vars (`bool`, *optional*, defaults to `True`):
+            Whether or not to unit-variance normalize the extracted features.
+ """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + num_mel_bins=80, + padding_value=0.0, + dither=0.0, + do_ceptral_normalize=True, + normalize_means=True, + normalize_vars=True, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.num_mel_bins = num_mel_bins + self.dither = dither + self.do_ceptral_normalize = do_ceptral_normalize + self.normalize_means = normalize_means + self.normalize_vars = normalize_vars + self.return_attention_mask = True + + if not is_speech_available(): + mel_filters = mel_filter_bank( + num_frequency_bins=257, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = mel_filters + self.window = window_function(400, "povey", periodic=False) + + def _extract_fbank_features( + self, + waveform: np.ndarray, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. + """ + waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers + if is_speech_available(): + waveform = torch.from_numpy(waveform).unsqueeze(0) + features = ta_kaldi.fbank( + waveform, + dither=self.dither, + num_mel_bins=self.num_mel_bins, + sample_frequency=self.sampling_rate, + ) + features = features.numpy() + else: + waveform = np.squeeze(waveform) + features = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + dither=self.dither, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + return features + + @staticmethod + def utterance_cmvn( + x: np.ndarray, + input_length: int, + normalize_means: Optional[bool] = True, + normalize_vars: Optional[bool] = True, + padding_value: float = 0.0, + ) -> np.ndarray: + # make sure we normalize float32 arrays + if normalize_means: + mean = x[:input_length].mean(axis=0) + x = np.subtract(x, mean) + if normalize_vars: + std = x[:input_length].std(axis=0) + x = np.divide(x, std) + + if input_length < x.shape[0]: + x[input_length:] = padding_value + + # make sure array is in float32 + x = x.astype(np.float32) + + return x + + def normalize( + self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None + ) -> List[np.ndarray]: + lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features] + return [ + self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value) + for x, n in zip(input_features, lengths) + ] + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). 
+
+        Args:
+            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            truncation (`bool`):
+                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific feature_extractor's default.
+
+                [What are attention masks?](../glossary#attention-mask)
+
+                <Tip>
+
+                For Speech2TextTransformer models, `attention_mask` should always be passed for batched inference, to
+                avoid subtle bugs.
+
+                </Tip>
+
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            sampling_rate (`int`, *optional*):
+                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                `sampling_rate` at the forward call to prevent silent errors.
+            padding_value (`float`, *optional*, defaults to 0.0):
+                The value that is used to fill the padding values / vectors.
+        """
+
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+                    f" {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+ ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features + features = [self._extract_fbank_features(waveform) for waveform in raw_speech] + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_features": features}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + # make sure list is in array format + input_features = padded_inputs.get("input_features") + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + # Utterance-level cepstral mean and variance normalization + if self.do_ceptral_normalize: + attention_mask = ( + np.array(attention_mask, dtype=np.int32) + if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD + else None + ) + padded_inputs["input_features"] = self.normalize( + padded_inputs["input_features"], attention_mask=attention_mask + ) + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["Speech2TextFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/docs/transformers/src/transformers/models/speech_to_text/modeling_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..fd52e64a1065c54050b55a72c6b9fa6b9491892f --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -0,0 +1,1342 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
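+# Length arithmetic sketch for the encoder below (assumed values; `num_conv_layers=2` is the
+# `Speech2TextConfig` default, but any configured depth behaves the same way):
+#
+#     T = 98                    # e.g. fbank frames for one second of 16 kHz audio
+#     for _ in range(2):        # each strided conv in Conv1dSubsampler roughly halves the length
+#         T = (T - 1) // 2 + 1  # mirrors Speech2TextPreTrainedModel._get_feat_extract_output_lengths
+#     # T == 25 encoder positions after subsampling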
+"""PyTorch Speech2Text model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_speech_to_text import Speech2TextConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Speech2TextConfig" + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class Conv1dSubsampler(nn.Module): + """ + Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation + via gated linear units (https://arxiv.org/abs/1911.08460) + """ + + def __init__(self, config): + super(Conv1dSubsampler, self).__init__() + self.config = config + self.num_layers = config.num_conv_layers + self.in_channels = config.input_feat_per_channel * config.input_channels + self.mid_channels = config.conv_channels + self.out_channels = config.d_model + self.kernel_sizes = config.conv_kernel_sizes + + self.conv_layers = nn.ModuleList( + nn.Conv1d( + self.in_channels if i == 0 else self.mid_channels // 2, + self.mid_channels if i < self.num_layers - 1 else self.out_channels * 2, + kernel_size=k, + stride=2, + padding=k // 2, + ) + for i, k in enumerate(self.kernel_sizes) + ) + + def forward(self, input_features): + hidden_states = input_features.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + hidden_states = conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + hidden_states = hidden_states.transpose(1, 2).contiguous() # -> T x B x (C x D) + return hidden_states + + +class Speech2TextSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.weights = nn.Parameter(emb_weights) + 
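+        # Freeze the sinusoidal table: it is fully deterministic and is rebuilt on demand in `forward`
+        # when longer sequences arrive, so it must never be updated by the optimizer.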
self.weights.requires_grad = False + self.weights.detach_() + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Speech2Text +class Speech2TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[Speech2TextConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention} + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT +class Speech2TextEncoderLayer(nn.Module): + def __init__(self, config: Speech2TextConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
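+
+            Note: when activations are in `torch.float16`, the output of this layer is clamped to just
+            inside the dtype's finite range to avoid `inf`/`nan` overflow (see the clamp in the body below).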
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT +class Speech2TextDecoderLayer(nn.Module): + def __init__(self, config: Speech2TextConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. 
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Speech2TextPreTrainedModel(PreTrainedModel): + config_class = Speech2TextConfig + base_model_prefix = "model" + main_input_name = "input_features" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + 
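+        # With stride 2 and padding k // 2 (the kernel sizes are odd by default, e.g. (5, 5)), the
+        # standard conv arithmetic L_out = (L_in + 2 * (k // 2) - k) // 2 + 1 reduces to the line
+        # below: each layer roughly halves the sequence length.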
+        for i in range(self.config.num_conv_layers):
+            input_lengths = (input_lengths - 1) // 2 + 1
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
+        # generate creates a 3D attention mask because of the shape of input_features;
+        # convert it to 2D if that's the case
+        if len(attention_mask.shape) > 2:
+            attention_mask = attention_mask[:, :, -1]
+
+        subsampled_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+        bsz = attention_mask.size()[0]
+        attention_mask = torch.zeros(
+            (bsz, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+
+        # these two operations make sure that all values
+        # before the output length indices are attended to
+        attention_mask[(torch.arange(bsz, device=attention_mask.device), subsampled_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
+        return attention_mask
+
+
+SPEECH_TO_TEXT_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Speech2TextConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SPEECH_TO_TEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
+            Float values of fbank features extracted from the raw speech waveform. The raw speech waveform can be
+            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features, padding
+            and conversion into a tensor of type `torch.FloatTensor`. See [`~Speech2TextFeatureExtractor.__call__`].
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
+            `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`SpeechToTextTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            SpeechToText uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read + [`modeling_speech_to_text._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Speech2TextEncoder(Speech2TextPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Speech2TextEncoderLayer`]. + + Args: + config: Speech2TextConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv = Conv1dSubsampler(config) + + self.embed_positions = Speech2TextSinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + self.layers = nn.ModuleList([Speech2TextEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`): + Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features, + padding and conversion into a tensor of type `torch.FloatTensor`. See + [`~Speech2TextFeatureExtractor.__call__`] + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
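+
+            A shape sketch (assumed values, for illustration only): `input_features` of shape `(1, 98, 80)`
+            passed through the default two stride-2 convolutions yields a `last_hidden_state` of shape
+            `(1, 25, config.d_model)`.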
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + inputs_embeds = self.conv(input_features) + inputs_embeds = self.embed_scale * inputs_embeds + + # subsample attention mask if necessary + if attention_mask is not None: + attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask) + padding_mask = attention_mask.ne(1).long() + else: + padding_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device) + + embed_pos = self.embed_positions(padding_mask) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class Speech2TextDecoder(Speech2TextPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`Speech2TextDecoderLayer`] + + Args: + config: Speech2TextConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = Speech2TextSinusoidalPositionalEmbedding( + self.max_target_positions, + config.d_model, + self.padding_idx, + ) + + self.layers = nn.ModuleList([Speech2TextDecoderLayer(config) for _ in range(config.decoder_layers)]) + + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`Speech2TextTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if 
self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare Speech2Text Model outputting raw hidden-states without any specific head on top.", + SPEECH_TO_TEXT_START_DOCSTRING, +) +class Speech2TextModel(Speech2TextPreTrainedModel): + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + + self.encoder = Speech2TextEncoder(config) + self.decoder = Speech2TextDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return 
self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import Speech2TextModel, AutoFeatureExtractor + >>> from datasets import load_dataset + + >>> model = Speech2TextModel.from_pretrained("facebook/s2t-small-librispeech-asr") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor( + ... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt" + ... 
) + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 256] + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # downsample encoder attention mask + if attention_mask is not None: + encoder_attention_mask = self._get_feature_vector_attention_mask( + encoder_outputs[0].shape[1], attention_mask + ) + else: + encoder_attention_mask = None + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The Speech2Text Model with a language modeling head. 
Can be used for automatic speech recognition and speech translation.",
+    SPEECH_TO_TEXT_START_DOCSTRING,
+)
+class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel, GenerationMixin):
+    base_model_prefix = "model"
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config: Speech2TextConfig):
+        super().__init__(config)
+        self.model = Speech2TextModel(config)
+        self.lm_head = nn.Linear(config.d_model, self.config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_features: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
+            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
+            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
+        >>> from datasets import load_dataset
+
+        >>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
+        >>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+        >>> inputs = processor(
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        ...
) + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + 'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = ["Speech2TextForConditionalGeneration", "Speech2TextModel", "Speech2TextPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/docs/transformers/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..7a20e462279a90c87c78630b595835c1558dd6fa --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -0,0 +1,1606 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
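+# Worked example for `shift_tokens_right` below (a sketch; the token ids are made up):
+#   labels                = [[42, 17, 9]]
+#   pad_token_id          = 1, decoder_start_token_id = 2
+#   shifted_input_ids     = [[2, 42, 17]]  # start token prepended, labels shifted right by one;
+#                                          # any -100 entries in the result become pad_token_id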
+"""TensorFlow Speech2Text model.""" + +from __future__ import annotations + +import random +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation, glu +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSharedEmbeddings, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_speech_to_text import Speech2TextConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Speech2TextConfig" +_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr" + + +LARGE_NEGATIVE = -1e8 + + +# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + start_tokens = tf.fill( + (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) + ) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE + mask_cond = tf.range(shape_list(mask)[-1]) + + mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) + + +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFConv1dSubsampler(keras.layers.Layer): + """ + Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation + via gated linear units (https://arxiv.org/abs/1911.08460) + """ + + def __init__(self, config: Speech2TextConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.num_layers = config.num_conv_layers + self.in_channels = config.input_feat_per_channel * config.input_channels + self.mid_channels = config.conv_channels + self.out_channels = config.d_model + self.kernel_sizes = config.conv_kernel_sizes + + self.conv_layers = [ + keras.layers.Conv1D( + filters=self.mid_channels if i < self.num_layers - 1 else self.out_channels * 2, + kernel_size=k, + strides=2, + name=f"conv_layers.{i}", + ) + for i, k in enumerate(self.kernel_sizes) + ] + + def call(self, input_features: tf.Tensor) -> tf.Tensor: + # TF Conv1D assumes Batch x Time x Channels, same as the input + hidden_states = tf.cast(input_features, tf.float32) + for i, conv in enumerate(self.conv_layers): + # equivalent to `padding=k // 2` on PT's `nn.Conv1d` + pad_len = self.kernel_sizes[i] // 2 + hidden_shapes = shape_list(hidden_states) + hidden_states = tf.concat( + ( + tf.zeros((hidden_shapes[0], pad_len, hidden_shapes[2])), + hidden_states, + tf.zeros((hidden_shapes[0], pad_len, hidden_shapes[2])), + ), + axis=1, + ) + + hidden_states = conv(hidden_states) + hidden_states = glu(hidden_states, axis=2) # GLU over the Channel dimension + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv_layers", None) is not None: + for i, layer in enumerate(self.conv_layers): + with tf.name_scope(layer.name): + layer.build([None, None, self.in_channels] if i == 0 else [None, None, self.mid_channels // 2]) + + +class TFSpeech2TextSinusoidalPositionalEmbedding(keras.layers.Layer): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs): + super().__init__(**kwargs) + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.embedding_weights = self._get_embedding(num_positions + self.offset, embedding_dim, padding_idx) + + @staticmethod + def _get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None) -> tf.Tensor: + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". 
+ """ + half_dim = embedding_dim // 2 + emb = tf.math.log(10000.0) / (half_dim - 1) + emb = tf.math.exp(tf.range(half_dim, dtype=tf.float32) * -emb) + emb = tf.expand_dims(tf.range(num_embeddings, dtype=tf.float32), axis=1) * tf.expand_dims(emb, axis=0) + emb = tf.reshape(tf.concat([tf.math.sin(emb), tf.math.cos(emb)], axis=1), shape=[num_embeddings, -1]) + if embedding_dim % 2 == 1: + # zero pad + emb = tf.concat([emb, tf.zeros(num_embeddings, 1)], axis=1) + if padding_idx is not None: + emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, tf.shape(emb)[1])), emb[padding_idx + 1 :, :]], axis=0) + return emb + + def call(self, input_ids: tf.Tensor, past_key_values_length: int = 0) -> tf.Tensor: + bsz, seq_len = shape_list(input_ids) + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + + # Matt: The PyTorch code does a lot of work to cache the embeddings, setting the cached values as a + # model attribute in the forward pass. This is extremely forbidden in TF, which wants forward calls to be + # idempotent. TF doesn't need that caching anyway, since it can just store constants during compilation, + # so we just remove all of that code. + embeddings = self._get_embedding( + self.padding_idx + 1 + seq_len + self.offset + past_key_values_length, self.embedding_dim, self.padding_idx + ) + return tf.reshape(tf.gather(embeddings, tf.reshape(position_ids, (-1,)), axis=0), (bsz, seq_len, -1)) + + @staticmethod + def create_position_ids_from_input_ids( + input_ids: tf.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ) -> tf.Tensor: + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: tf.Tensor x: + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=tf.int32) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + return tf.cast(incremental_indices, dtype=tf.int64) + padding_idx + + +# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Speech2Text +class TFSpeech2TextAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: Tuple[Tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +class TFSpeech2TextEncoderLayer(keras.layers.Layer): + def __init__(self, config: Speech2TextConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFSpeech2TextAttention( + 
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False + ): + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + training=training, + ) + + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return hidden_states, self_attn_weights + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.encoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFSpeech2TextDecoderLayer(keras.layers.Layer): + def __init__(self, config: Speech2TextConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + + self.self_attn = TFSpeech2TextAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + 
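+        # Pre-norm residual blocks: LayerNorm is applied before each sub-layer (self-attention,
+        # cross-attention, feed-forward) and the residual is added after dropout. The per-layer
+        # cache is a 4-tuple: self-attention key/value states first, cross-attention states last.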
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFSpeech2TextAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states, + attention_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + cross_attn_layer_head_mask: tf.Tensor | None = None, + past_key_value: Tuple[tf.Tensor] | None = None, + training=False, + ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`tf.Tensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`tf.Tensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. 
+ `(decoder_attention_heads,)` + past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + training=training, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + training=training, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return ( + hidden_states, + self_attn_weights, + cross_attn_weights, + present_key_value, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "encoder_attn", None) is not None: + with tf.name_scope(self.encoder_attn.name): + self.encoder_attn.build(None) + if getattr(self, "encoder_attn_layer_norm", None) is not None: + with tf.name_scope(self.encoder_attn_layer_norm.name): + self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.decoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFSpeech2TextPreTrainedModel(TFPreTrainedModel): + config_class = Speech2TextConfig + base_model_prefix = "model" + 
main_input_name = "input_features" + _keys_to_ignore_on_load_unexpected = [r"encoder.embed_positions.weights"] + + def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor): + """ + Computes the output length of the convolutional layers + """ + for _ in range(self.config.num_conv_layers): + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + @property + def input_signature(self): + return { + "input_features": tf.TensorSpec( + (None, None, self.config.input_feat_per_channel * self.config.input_channels), + tf.float32, + name="input_features", + ), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"), + "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"), + } + + +SPEECH_TO_TEXT_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`Speech2TextConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SPEECH_TO_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_features (`tf.Tensor` of shape `(batch_size, sequence_length, feature_size)`): + Float values of fbank features extracted from the raw speech waveform. 
Raw speech waveform can be obtained
+            by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.*
+            via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
+            [`AutoFeatureExtractor`] should be used for extracting the fbank features, padding and conversion into a
+            tensor of floats. See [`~Speech2TextFeatureExtractor.__call__`]
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`Speech2TextTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            SpeechToText uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For training, `decoder_input_ids` should be provided. If they are omitted and `labels` are passed, the
+            model creates this tensor by shifting the `labels` one position to the right.
+        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            By default, a mask that ignores pad tokens in `decoder_input_ids` will be made, and a causal mask will
+            also be used. It is not recommended to set this for most use cases.
+        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tf.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+            the decoder.
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
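+            Each inner tuple holds the cached self-attention key/value states of one decoder layer,
+            followed by its cross-attention key/value states.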
+ decoder_inputs_embeds (`tf.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@keras_serializable +class TFSpeech2TextEncoder(keras.layers.Layer): + config_class = Speech2TextConfig + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`TFSpeech2TextEncoderLayer`]. 
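+
+    The log-mel input features are first downsampled by [`TFConv1dSubsampler`]; each of the
+    `config.num_conv_layers` strided convolutions roughly halves the time axis, so self-attention
+    operates on a sequence about `2**config.num_conv_layers` times shorter than the raw frames.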
+
+    Args:
+        config: Speech2TextConfig
+    """
+
+    def __init__(self, config: Speech2TextConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+
+        self.dropout = keras.layers.Dropout(config.dropout)
+        self.layerdrop = config.encoder_layerdrop
+
+        embed_dim = config.d_model
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_source_positions
+        self.embed_scale = tf.math.sqrt(float(embed_dim)) if config.scale_embedding else 1.0
+
+        self.conv = TFConv1dSubsampler(config, name="conv")
+
+        self.embed_positions = TFSpeech2TextSinusoidalPositionalEmbedding(
+            num_positions=config.max_source_positions,
+            embedding_dim=embed_dim,
+            padding_idx=self.padding_idx,
+            name="embed_positions",
+        )
+        self.layers = [TFSpeech2TextEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+    def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
+        """
+        Computes the output length of the convolutional layers
+        """
+        for _ in range(self.config.num_conv_layers):
+            input_lengths = (input_lengths - 1) // 2 + 1
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
+        # `generate` creates a 3D attention mask because of the shape of `input_features`;
+        # convert it to 2D if that's the case
+        if len(attention_mask.shape) > 2:
+            attention_mask = attention_mask[:, :, -1]
+
+        subsampled_lengths = self._get_feat_extract_output_lengths(tf.math.reduce_sum(attention_mask, -1))
+        bsz = shape_list(attention_mask)[0]
+        indices = tf.concat(
+            (
+                tf.expand_dims(tf.range(bsz, dtype=attention_mask.dtype), -1),
+                tf.expand_dims(subsampled_lengths - 1, -1),
+            ),
+            axis=-1,
+        )
+        attention_mask = tf.scatter_nd(indices=indices, updates=tf.ones(bsz), shape=[bsz, feature_vector_length])
+        attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64)
+        return attention_mask
+
+    @unpack_inputs
+    def call(
+        self,
+        input_features=None,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+    ):
+        """
+        Args:
+            input_features (`tf.Tensor` of shape `(batch_size, sequence_length, feature_size)`):
+                Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
+                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
+                padding and conversion into a tensor of floats. See [`~Speech2TextFeatureExtractor.__call__`]
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers.
See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + if input_features is None: + raise ValueError("You have to specify input_features") + + inputs_embeds = self.conv(input_features) + inputs_embeds = self.embed_scale * inputs_embeds + + # subsample attention mask if necessary + if attention_mask is not None: + attention_mask = self._get_feature_vector_attention_mask(tf.shape(inputs_embeds)[1], attention_mask) + padding_mask = tf.cast(tf.math.not_equal(attention_mask, 1), tf.int64) + else: + padding_mask = tf.zeros(tf.shape(inputs_embeds)[:-1], dtype=tf.int64) + + embed_pos = self.embed_positions(padding_mask) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout(hidden_states, training=training) + + # check attention mask and invert + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + tf.debugging.assert_equal( + shape_list(head_mask)[0], + len(self.layers), + message=( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(head_mask)[0]}." + ), + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): # skip the layer + continue + + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + training=training, + ) + + if output_attentions: + all_attentions += (attn,) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build(None) + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFSpeech2TextDecoder(keras.layers.Layer): + config_class = Speech2TextConfig + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`TFSpeech2TextDecoderLayer`] + + Args: + config: Speech2TextConfig + """ + + def __init__(self, config: Speech2TextConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + + self.embed_tokens = TFSharedEmbeddings(config.vocab_size, config.d_model, name="embed_tokens") + + self.embed_positions = TFSpeech2TextSinusoidalPositionalEmbedding( + num_positions=config.max_target_positions, + embedding_dim=config.d_model, + padding_idx=self.padding_idx, + name="embed_positions", + ) + + self.layers = [TFSpeech2TextDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + self.dropout = keras.layers.Dropout(config.dropout) + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`Speech2TextTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + else: + inputs_embeds = inputs_embeds + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] + ) + + if attention_mask is not None: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = self.dropout(hidden_states, training=training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for 
attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." + ), + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + ) + + if use_cache: + next_decoder_cache += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if encoder_hidden_states is not None: + all_cross_attns += (layer_cross_attn,) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns + else: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_tokens", None) is not None: + with tf.name_scope(self.embed_tokens.name): + self.embed_tokens.build(None) + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFSpeech2TextMainLayer(keras.layers.Layer): + config_class = Speech2TextConfig + + def __init__(self, config: Speech2TextConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.encoder = TFSpeech2TextEncoder(config, name="encoder") + self.decoder = TFSpeech2TextDecoder(config, name="decoder") + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.decoder.embed_tokens = new_embeddings + + @unpack_inputs + def call( + self, + input_features=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + 
**kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + # downsample encoder attention mask + if attention_mask is not None: + encoder_attention_mask = self.encoder._get_feature_vector_attention_mask( + tf.shape(encoder_outputs[0])[1], attention_mask + ) + else: + encoder_attention_mask = None + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare Speech2Text Model outputting raw hidden-states without any specific head on top.", + SPEECH_TO_TEXT_START_DOCSTRING, +) +class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel): + def __init__(self, config: Speech2TextConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFSpeech2TextMainLayer(config, name="model") + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @unpack_inputs + 
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSeq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_features: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + cross_attn_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[Tuple, TFSeq2SeqModelOutput]: + outputs = self.model( + input_features=input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqModelOutput( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +@add_start_docstrings( + "The Speech2Text Model with a language modeling head. 
Can be used for automatic speech recognition and speech translation.",
+    SPEECH_TO_TEXT_START_DOCSTRING,
+)
+class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCausalLanguageModelingLoss):
+    def __init__(self, config: Speech2TextConfig):
+        super().__init__(config)
+        self.model = TFSpeech2TextMainLayer(config, name="model")
+        self.lm_head = keras.layers.Dense(self.config.vocab_size, use_bias=False, name="lm_head")
+        # TODO (Joao): investigate why Speech2Text has numerical issues in XLA generate
+        self.supports_xla_generation = False
+        self.config = config
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def resize_token_embeddings(self, new_num_tokens: int) -> tf.Variable:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens)
+        return new_embeddings
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_features: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+        encoder_outputs: np.ndarray | tf.Tensor | None = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+        **kwargs,
+    ) -> Union[Tuple, TFSeq2SeqLMOutput]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import Speech2TextProcessor, TFSpeech2TextForConditionalGeneration
+        >>> from datasets import load_dataset
+        >>> import soundfile as sf
+
+        >>> model = TFSpeech2TextForConditionalGeneration.from_pretrained(
+        ...     "facebook/s2t-small-librispeech-asr", from_pt=True
+        ... )
+        >>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
+
+
+        >>> def map_to_array(batch):
+        ...     speech, _ = sf.read(batch["file"])
+        ...     batch["speech"] = speech
+        ...     return batch
+
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> ds = ds.map(map_to_array)
+        >>> ds.set_format(type="tf")
+
+        >>> input_features = processor(
+        ...     ds["speech"][0], sampling_rate=16000, return_tensors="tf"
+        ...
).input_features # Batch size 1 + >>> generated_ids = model.generate(input_features) + + >>> transcription = processor.batch_decode(generated_ids) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_features=input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + lm_logits = self.lm_head(outputs[0]) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return TFSeq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_features": None, # needs to be passed to make Keras.layer.__call__ happy + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def build(self, input_shape=None): + if self.built: + 
return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build([None, None, self.config.d_model]) + + def tf_to_pt_weight_rename(self, tf_weight): + if tf_weight == "lm_head.weight": + return tf_weight, "model.decoder.embed_tokens.weight" + else: + return (tf_weight,) + + +__all__ = ["TFSpeech2TextForConditionalGeneration", "TFSpeech2TextModel", "TFSpeech2TextPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/speech_to_text/processing_speech_to_text.py b/docs/transformers/src/transformers/models/speech_to_text/processing_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..724c0a6ed0f77877b9de500cc212c8da5c7fcc25 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/processing_speech_to_text.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Speech processor class for Speech2Text +""" + +import warnings +from contextlib import contextmanager + +from ...processing_utils import ProcessorMixin + + +class Speech2TextProcessor(ProcessorMixin): + r""" + Constructs a Speech2Text processor which wraps a Speech2Text feature extractor and a Speech2Text tokenizer into a + single processor. + + [`Speech2TextProcessor`] offers all the functionalities of [`Speech2TextFeatureExtractor`] and + [`Speech2TextTokenizer`]. See the [`~Speech2TextProcessor.__call__`] and [`~Speech2TextProcessor.decode`] for more + information. + + Args: + feature_extractor (`Speech2TextFeatureExtractor`): + An instance of [`Speech2TextFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`Speech2TextTokenizer`): + An instance of [`Speech2TextTokenizer`]. The tokenizer is a required input. + """ + + feature_extractor_class = "Speech2TextFeatureExtractor" + tokenizer_class = "Speech2TextTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + def __call__(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to Speech2TextFeatureExtractor's + [`~Speech2TextFeatureExtractor.__call__`] and returns its output. If used in the context + [`~Speech2TextProcessor.as_target_processor`] this method forwards all its arguments to Speech2TextTokenizer's + [`~Speech2TextTokenizer.__call__`]. Please refer to the docstring of the above two methods for more + information. + """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + if "raw_speech" in kwargs: + warnings.warn("Using `raw_speech` as a keyword argument is deprecated. 
Use `audio` instead.") + audio = kwargs.pop("raw_speech") + else: + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if audio is None and text is None: + raise ValueError("You need to specify either an `audio` or `text` input to process.") + + if audio is not None: + inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + if text is not None: + encodings = self.tokenizer(text, **kwargs) + + if text is None: + return inputs + elif audio is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Speech2TextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Speech2TextTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning + Speech2Text. + """ + warnings.warn( + "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your " + "labels by using the argument `text` of the regular `__call__` method (either in the same call as " + "your audio inputs, or in a separate call." + ) + self._in_target_context_manager = True + self.current_processor = self.tokenizer + yield + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + +__all__ = ["Speech2TextProcessor"] diff --git a/docs/transformers/src/transformers/models/speech_to_text/tokenization_speech_to_text.py b/docs/transformers/src/transformers/models/speech_to_text/tokenization_speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..abb6dd5da7a5ca3c1505b870e9030d1ac33900f1 --- /dev/null +++ b/docs/transformers/src/transformers/models/speech_to_text/tokenization_speech_to_text.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes for Speech2Text.""" + +import json +import os +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import sentencepiece + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "spm_file": "sentencepiece.bpe.model", +} + + +MAX_MODEL_INPUT_SIZES = { + "facebook/s2t-small-librispeech-asr": 1024, +} + +MUSTC_LANGS = ["pt", "fr", "ru", "nl", "ro", "it", "es", "de"] + +LANGUAGES = {"mustc": MUSTC_LANGS} + + +@requires(backends=("sentencepiece",)) +class Speech2TextTokenizer(PreTrainedTokenizer): + """ + Construct an Speech2Text tokenizer. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to + the superclass for more information regarding such methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + spm_file (`str`): + Path to the [SentencePiece](https://github.com/google/sentencepiece) model file + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sentence token. + eos_token (`str`, *optional*, defaults to `""`): + The end of sentence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + do_upper_case (`bool`, *optional*, defaults to `False`): + Whether or not to uppercase the output when decoding. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + tgt_lang (`str`, *optional*): + A string representing the target language. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. 
+ + **kwargs + Additional keyword arguments passed along to [`PreTrainedTokenizer`] + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: List[int] = [] + + def __init__( + self, + vocab_file, + spm_file, + bos_token="<s>", + eos_token="</s>", + pad_token="<pad>", + unk_token="<unk>", + do_upper_case=False, + do_lower_case=False, + tgt_lang=None, + lang_codes=None, + additional_special_tokens=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_upper_case = do_upper_case + self.do_lower_case = do_lower_case + + self.encoder = load_json(vocab_file) + self.decoder = {v: k for k, v in self.encoder.items()} + self.spm_file = spm_file + self.sp_model = load_spm(spm_file, self.sp_model_kwargs) + + if lang_codes is not None: + self.lang_codes = lang_codes + self.langs = LANGUAGES[lang_codes] + self.lang_tokens = [f"<lang:{lang}>" for lang in self.langs] + self.lang_code_to_id = {lang: self.sp_model.PieceToId(f"<lang:{lang}>") for lang in self.langs} + if additional_special_tokens is not None: + additional_special_tokens = self.lang_tokens + additional_special_tokens + else: + additional_special_tokens = self.lang_tokens + self._tgt_lang = tgt_lang if tgt_lang is not None else self.langs[0] + + self.set_tgt_lang_special_tokens(self._tgt_lang) + else: + self.lang_code_to_id = {} + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + do_upper_case=do_upper_case, + do_lower_case=do_lower_case, + tgt_lang=tgt_lang, + lang_codes=lang_codes, + sp_model_kwargs=self.sp_model_kwargs, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.encoder) + + def get_vocab(self) -> Dict: + vocab = self.encoder.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang) -> None: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(new_tgt_lang) + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting.
prefix=[tgt_lang_code] and suffix=[eos].""" + lang_code_id = self.lang_code_to_id[tgt_lang] + self.prefix_tokens = [lang_code_id] + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + return self.encoder.get(token, self.encoder[self.unk_token]) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) to a token (str) using the decoder.""" + return self.decoder.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens (strings for sub-words) into a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + decoded = self.sp_model.decode(current_sub_tokens) + out_string += (decoded.upper() if self.do_upper_case else decoded) + token + " " + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + decoded = self.sp_model.decode(current_sub_tokens) + out_string += decoded.upper() if self.do_upper_case else decoded + return out_string.strip() + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: + """Build model inputs from a sequence by prepending the language-code prefix tokens and appending eos_token_id.""" + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id] + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def __getstate__(self) -> Dict: + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d: Dict) -> None: + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = load_spm(self.spm_file, self.sp_model_kwargs) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + save_dir = Path(save_directory) + assert save_dir.is_dir(), f"{save_directory} should be a directory" + vocab_save_path = save_dir / ( + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"] + ) + spm_save_path = save_dir / ( + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["spm_file"] + ) + + save_json(self.encoder, vocab_save_path) + + if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file): + copyfile(self.spm_file, spm_save_path) + elif not os.path.isfile(self.spm_file): + with open(spm_save_path, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (str(vocab_save_path), str(spm_save_path)) + + +def load_spm(path: str, sp_model_kwargs: Dict[str, Any]) -> sentencepiece.SentencePieceProcessor: + spm = sentencepiece.SentencePieceProcessor(**sp_model_kwargs) + spm.Load(str(path)) + return spm + + +def load_json(path: str) -> Union[Dict, List]: + with open(path, "r") as f: + return json.load(f) + + +def save_json(data, path: str) -> None: + with open(path, "w") as f: + json.dump(data, f, indent=2) + + +__all__ = ["Speech2TextTokenizer"] diff --git a/docs/transformers/src/transformers/models/speecht5/__init__.py b/docs/transformers/src/transformers/models/speecht5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52fee59ae9b96d14361a771cbc58089d5e5811f2 --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_speecht5 import * + from .feature_extraction_speecht5 import * + from .modeling_speecht5 import * + from .processing_speecht5 import * + from .tokenization_speecht5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/speecht5/configuration_speecht5.py b/docs/transformers/src/transformers/models/speecht5/configuration_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..6f79bdbc61d3cd7a8ed2470499c629bd703dab90 --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/configuration_speecht5.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SpeechT5 model configuration""" + +import functools +import operator + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SpeechT5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SpeechT5Model`]. It is used to instantiate a + SpeechT5 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the SpeechT5 + [microsoft/speecht5_asr](https://huggingface.co/microsoft/speecht5_asr) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 81): + Vocabulary size of the SpeechT5 model. Defines the number of different tokens that can be represented by + the `input_ids` passed to the forward method of [`SpeechT5Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + encoder_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. + decoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer decoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder. + decoder_layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + positional_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the text position encoding layers. + hidden_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for activations inside the fully connected layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by dividing by sqrt(d_model). + feat_extract_norm (`str`, *optional*, defaults to `"group"`): + The norm to be applied to 1D convolutional layers in the speech encoder pre-net. One of `"group"` for group + normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D + convolutional layers. + feat_proj_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for output of the speech encoder pre-net. + feat_extract_activation (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the 1D convolutional layers of the feature + extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. + conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`): + A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the + speech encoder pre-net. The length of *conv_dim* defines the number of 1D convolutional layers. + conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`): + A tuple of integers defining the stride of each 1D convolutional layer in the speech encoder pre-net. The + length of *conv_stride* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the speech encoder pre-net. + The length of *conv_kernel* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether the 1D convolutional layers have a bias. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer. + apply_spec_augment (`bool`, *optional*, defaults to `True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the speech encoder pre-net. For + reference see [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If + reasoning from the probability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2): + The minimum number of masks of length `mask_time_length` generated along the time axis, each time step, + irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks`. + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over + the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0): + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespective of `mask_feature_prob`. Only relevant if + `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`. + num_mel_bins (`int`, *optional*, defaults to 80): + Number of mel features used per input feature. Used by the speech decoder pre-net. Should correspond to + the value used in the [`SpeechT5Processor`] class. + speech_decoder_prenet_layers (`int`, *optional*, defaults to 2): + Number of layers in the speech decoder pre-net. + speech_decoder_prenet_units (`int`, *optional*, defaults to 256): + Dimensionality of the layers in the speech decoder pre-net. + speech_decoder_prenet_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability for the speech decoder pre-net layers. + speaker_embedding_dim (`int`, *optional*, defaults to 512): + Dimensionality of the *XVector* embedding vectors. + speech_decoder_postnet_layers (`int`, *optional*, defaults to 5): + Number of layers in the speech decoder post-net. + speech_decoder_postnet_units (`int`, *optional*, defaults to 256): + Dimensionality of the layers in the speech decoder post-net.
+ speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5): + Kernel size of the 1D convolutional layers in the speech decoder post-net. + speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability for the speech decoder post-net layers. + reduction_factor (`int`, *optional*, defaults to 2): + Spectrogram length reduction factor for the speech decoder inputs. + max_speech_positions (`int`, *optional*, defaults to 4000): + The maximum sequence length of speech features that this model might ever be used with. + max_text_positions (`int`, *optional*, defaults to 450): + The maximum sequence length of text features that this model might ever be used with. + encoder_max_relative_position (`int`, *optional*, defaults to 160): + Maximum distance for relative position embedding in the encoder. + use_guided_attention_loss (`bool`, *optional*, defaults to `True`): + Whether to apply guided attention loss while training the TTS model. + guided_attention_loss_num_heads (`int`, *optional*, defaults to 2): + Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all + attention heads. + guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4): + Standard deviation for guided attention loss. + guided_attention_loss_scale (`float`, *optional*, defaults to 10.0): + Scaling coefficient for guided attention loss (also known as lambda). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + + Example: + + ```python + >>> from transformers import SpeechT5Model, SpeechT5Config + + >>> # Initializing a "microsoft/speecht5_asr" style configuration + >>> configuration = SpeechT5Config() + + >>> # Initializing a model (with random weights) from the "microsoft/speecht5_asr" style configuration + >>> model = SpeechT5Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "speecht5" + attribute_map = {"num_attention_heads": "encoder_attention_heads", "num_hidden_layers": "encoder_layers"} + + def __init__( + self, + vocab_size=81, + hidden_size=768, + encoder_layers=12, + encoder_attention_heads=12, + encoder_ffn_dim=3072, + encoder_layerdrop=0.1, + decoder_layers=6, + decoder_ffn_dim=3072, + decoder_attention_heads=12, + decoder_layerdrop=0.1, + hidden_act="gelu", + positional_dropout=0.1, + hidden_dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + scale_embedding=False, + feat_extract_norm="group", + feat_proj_dropout=0.0, + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + decoder_start_token_id=2, + num_mel_bins=80, + speech_decoder_prenet_layers=2, + speech_decoder_prenet_units=256, + speech_decoder_prenet_dropout=0.5, + speaker_embedding_dim=512, + speech_decoder_postnet_layers=5, + speech_decoder_postnet_units=256, + speech_decoder_postnet_kernel=5, + speech_decoder_postnet_dropout=0.5, + reduction_factor=2, + max_speech_positions=4000, + max_text_positions=450, +
encoder_max_relative_position=160, + use_guided_attention_loss=True, + guided_attention_loss_num_heads=2, + guided_attention_loss_sigma=0.4, + guided_attention_loss_scale=10.0, + use_cache=True, + is_encoder_decoder=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_attention_heads = decoder_attention_heads + self.decoder_layerdrop = decoder_layerdrop + self.hidden_act = hidden_act + self.positional_dropout = positional_dropout + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.scale_embedding = scale_embedding + + self.feat_extract_norm = feat_extract_norm + self.feat_proj_dropout = feat_proj_dropout + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." 
+ ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + self.num_mel_bins = num_mel_bins + self.speech_decoder_prenet_layers = speech_decoder_prenet_layers + self.speech_decoder_prenet_units = speech_decoder_prenet_units + self.speech_decoder_prenet_dropout = speech_decoder_prenet_dropout + self.speaker_embedding_dim = speaker_embedding_dim + + self.speech_decoder_postnet_layers = speech_decoder_postnet_layers + self.speech_decoder_postnet_units = speech_decoder_postnet_units + self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel + self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout + self.reduction_factor = reduction_factor + + self.max_speech_positions = max_speech_positions + self.max_text_positions = max_text_positions + self.encoder_max_relative_position = encoder_max_relative_position + + self.use_guided_attention_loss = use_guided_attention_loss + self.guided_attention_loss_num_heads = guided_attention_loss_num_heads + self.guided_attention_loss_sigma = guided_attention_loss_sigma + self.guided_attention_loss_scale = guided_attention_loss_scale + + self.use_cache = use_cache + self.is_encoder_decoder = is_encoder_decoder + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +class SpeechT5HifiGanConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SpeechT5HifiGanModel`]. It is used to instantiate + a SpeechT5 HiFi-GAN vocoder model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the SpeechT5 + [microsoft/speecht5_hifigan](https://huggingface.co/microsoft/speecht5_hifigan) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + model_in_dim (`int`, *optional*, defaults to 80): + The number of frequency bins in the input log-mel spectrogram. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the output audio will be generated, expressed in hertz (Hz). + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the upsampling network. + upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[4, 4, 4, 4]`): + A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The + length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. + upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 8, 8]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. 
The + length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of + *upsample_rates*. + resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field + fusion (MRF) module. + resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + multi-receptive field fusion (MRF) module. + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation. + normalize_before (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance. + + Example: + + ```python + >>> from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig + + >>> # Initializing a "microsoft/speecht5_hifigan" style configuration + >>> configuration = SpeechT5HifiGanConfig() + + >>> # Initializing a model (with random weights) from the "microsoft/speecht5_hifigan" style configuration + >>> model = SpeechT5HifiGan(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "hifigan" + + def __init__( + self, + model_in_dim=80, + sampling_rate=16000, + upsample_initial_channel=512, + upsample_rates=[4, 4, 4, 4], + upsample_kernel_sizes=[8, 8, 8, 8], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + initializer_range=0.01, + leaky_relu_slope=0.1, + normalize_before=True, + **kwargs, + ): + self.model_in_dim = model_in_dim + self.sampling_rate = sampling_rate + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.initializer_range = initializer_range + self.leaky_relu_slope = leaky_relu_slope + self.normalize_before = normalize_before + super().__init__(**kwargs) + + +__all__ = ["SpeechT5Config", "SpeechT5HifiGanConfig"] diff --git a/docs/transformers/src/transformers/models/speecht5/convert_hifigan.py b/docs/transformers/src/transformers/models/speecht5/convert_hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..b39012f8e251cef0e164629dcbca1cecac26ac9e --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/convert_hifigan.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
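A note on the HiFi-GAN defaults above, relevant to the conversion script that follows: the vocoder's total upsampling factor is the product of `upsample_rates`, which ties the spectrogram frame rate to the audio sampling rate. A quick sanity-check sketch (an editorial illustration using the config defaults shown above):

```python
import math

# With the SpeechT5HifiGanConfig defaults, each input spectrogram frame
# expands into prod(upsample_rates) audio samples:
upsample_rates = [4, 4, 4, 4]
samples_per_frame = math.prod(upsample_rates)  # 4 * 4 * 4 * 4 = 256

# At the default sampling_rate of 16000 Hz, that is 256 / 16000 = 0.016 s,
# i.e. one spectrogram frame for every 16 ms of generated audio.
print(samples_per_frame, samples_per_frame / 16000)
```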
+"""Convert SpeechT5 HiFi-GAN checkpoint.""" + +import argparse + +import numpy as np +import torch + +from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig, logging + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.speecht5") + + +def load_weights(checkpoint, hf_model, config): + hf_model.apply_weight_norm() + + hf_model.conv_pre.weight_g.data = checkpoint["input_conv.weight_g"] + hf_model.conv_pre.weight_v.data = checkpoint["input_conv.weight_v"] + hf_model.conv_pre.bias.data = checkpoint["input_conv.bias"] + + for i in range(len(config.upsample_rates)): + hf_model.upsampler[i].weight_g.data = checkpoint[f"upsamples.{i}.1.weight_g"] + hf_model.upsampler[i].weight_v.data = checkpoint[f"upsamples.{i}.1.weight_v"] + hf_model.upsampler[i].bias.data = checkpoint[f"upsamples.{i}.1.bias"] + + for i in range(len(config.upsample_rates) * len(config.resblock_kernel_sizes)): + for j in range(len(config.resblock_dilation_sizes)): + hf_model.resblocks[i].convs1[j].weight_g.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_g"] + hf_model.resblocks[i].convs1[j].weight_v.data = checkpoint[f"blocks.{i}.convs1.{j}.1.weight_v"] + hf_model.resblocks[i].convs1[j].bias.data = checkpoint[f"blocks.{i}.convs1.{j}.1.bias"] + + hf_model.resblocks[i].convs2[j].weight_g.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_g"] + hf_model.resblocks[i].convs2[j].weight_v.data = checkpoint[f"blocks.{i}.convs2.{j}.1.weight_v"] + hf_model.resblocks[i].convs2[j].bias.data = checkpoint[f"blocks.{i}.convs2.{j}.1.bias"] + + hf_model.conv_post.weight_g.data = checkpoint["output_conv.1.weight_g"] + hf_model.conv_post.weight_v.data = checkpoint["output_conv.1.weight_v"] + hf_model.conv_post.bias.data = checkpoint["output_conv.1.bias"] + + hf_model.remove_weight_norm() + + +@torch.no_grad() +def convert_hifigan_checkpoint( + checkpoint_path, + stats_path, + pytorch_dump_folder_path, + config_path=None, + repo_id=None, +): + if config_path is not None: + config = SpeechT5HifiGanConfig.from_pretrained(config_path) + else: + config = SpeechT5HifiGanConfig() + + model = SpeechT5HifiGan(config) + + orig_checkpoint = torch.load(checkpoint_path, weights_only=True) + load_weights(orig_checkpoint["model"]["generator"], model, config) + + stats = np.load(stats_path) + mean = stats[0].reshape(-1) + scale = stats[1].reshape(-1) + model.mean = torch.from_numpy(mean).float() + model.scale = torch.from_numpy(scale).float() + + model.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--stats_path", required=True, default=None, type=str, help="Path to stats.npy file") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_hifigan_checkpoint( + args.checkpoint_path, + args.stats_path, + args.pytorch_dump_folder_path, + args.config_path, + args.push_to_hub, + ) diff --git a/docs/transformers/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c16e11d2b250f041f6cf36b0ea71d77632e97e5c --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,401 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert SpeechT5 checkpoint.""" + +import argparse + +import torch + +from transformers import ( + SpeechT5Config, + SpeechT5FeatureExtractor, + SpeechT5ForSpeechToSpeech, + SpeechT5ForSpeechToText, + SpeechT5ForTextToSpeech, + SpeechT5Processor, + SpeechT5Tokenizer, + logging, +) +from transformers.tokenization_utils import AddedToken + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.speecht5") + +MAPPING_SPEECH_ENCODER_PRENET = { + "speech_encoder_prenet.layer_norm": "speecht5.encoder.prenet.feature_projection.layer_norm", + "speech_encoder_prenet.post_extract_proj": "speecht5.encoder.prenet.feature_projection.projection", + "speech_encoder_prenet.pos_conv.0": "speecht5.encoder.prenet.pos_conv_embed.conv", + "speech_encoder_prenet.mask_emb": "speecht5.encoder.prenet.masked_spec_embed", +} +MAPPING_TEXT_ENCODER_PRENET = { + "text_encoder_prenet.encoder_prenet.0": "speecht5.encoder.prenet.embed_tokens", + "text_encoder_prenet.encoder_prenet.1.alpha": "speecht5.encoder.prenet.encode_positions.alpha", +} +MAPPING_SPEECH_DECODER_PRENET = { + "speech_decoder_prenet.decoder_prenet.0.0.prenet.0.0": "speecht5.decoder.prenet.layers.0", + "speech_decoder_prenet.decoder_prenet.0.0.prenet.1.0": "speecht5.decoder.prenet.layers.1", + "speech_decoder_prenet.decoder_prenet.0.1": "speecht5.decoder.prenet.final_layer", + "speech_decoder_prenet.decoder_prenet.1.alpha": "speecht5.decoder.prenet.encode_positions.alpha", + "speech_decoder_prenet.spkembs_layer.0": "speecht5.decoder.prenet.speaker_embeds_layer", +} +MAPPING_SPEECH_DECODER_POSTNET = { + "speech_decoder_postnet.feat_out": "speech_decoder_postnet.feat_out", + "speech_decoder_postnet.prob_out": "speech_decoder_postnet.prob_out", + "speech_decoder_postnet.postnet.postnet.0.0": "speech_decoder_postnet.layers.0.conv", + "speech_decoder_postnet.postnet.postnet.0.1": "speech_decoder_postnet.layers.0.batch_norm", + "speech_decoder_postnet.postnet.postnet.1.0": "speech_decoder_postnet.layers.1.conv", + "speech_decoder_postnet.postnet.postnet.1.1": "speech_decoder_postnet.layers.1.batch_norm", + "speech_decoder_postnet.postnet.postnet.2.0": "speech_decoder_postnet.layers.2.conv", + "speech_decoder_postnet.postnet.postnet.2.1": 
"speech_decoder_postnet.layers.2.batch_norm", + "speech_decoder_postnet.postnet.postnet.3.0": "speech_decoder_postnet.layers.3.conv", + "speech_decoder_postnet.postnet.postnet.3.1": "speech_decoder_postnet.layers.3.batch_norm", + "speech_decoder_postnet.postnet.postnet.4.0": "speech_decoder_postnet.layers.4.conv", + "speech_decoder_postnet.postnet.postnet.4.1": "speech_decoder_postnet.layers.4.batch_norm", +} +MAPPING_TEXT_DECODER_PRENET = { + "text_decoder_prenet.embed_tokens": "speecht5.decoder.prenet.embed_tokens", +} +MAPPING_TEXT_DECODER_POSTNET = { + "text_decoder_postnet.output_projection": "text_decoder_postnet.lm_head", +} +MAPPING_ENCODER = { + "encoder.layers.*.self_attn.k_proj": "speecht5.encoder.wrapped_encoder.layers.*.attention.k_proj", + "encoder.layers.*.self_attn.v_proj": "speecht5.encoder.wrapped_encoder.layers.*.attention.v_proj", + "encoder.layers.*.self_attn.q_proj": "speecht5.encoder.wrapped_encoder.layers.*.attention.q_proj", + "encoder.layers.*.self_attn.out_proj": "speecht5.encoder.wrapped_encoder.layers.*.attention.out_proj", + "encoder.layers.*.self_attn_layer_norm": "speecht5.encoder.wrapped_encoder.layers.*.layer_norm", + "encoder.layers.*.fc1": "speecht5.encoder.wrapped_encoder.layers.*.feed_forward.intermediate_dense", + "encoder.layers.*.fc2": "speecht5.encoder.wrapped_encoder.layers.*.feed_forward.output_dense", + "encoder.layers.*.final_layer_norm": "speecht5.encoder.wrapped_encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "speecht5.encoder.wrapped_encoder.layer_norm", + "encoder.pos_emb.pe_k": "speecht5.encoder.wrapped_encoder.embed_positions.pe_k", +} +MAPPING_DECODER = { + "decoder.layers.*.self_attn.k_proj": "speecht5.decoder.wrapped_decoder.layers.*.self_attn.k_proj", + "decoder.layers.*.self_attn.v_proj": "speecht5.decoder.wrapped_decoder.layers.*.self_attn.v_proj", + "decoder.layers.*.self_attn.q_proj": "speecht5.decoder.wrapped_decoder.layers.*.self_attn.q_proj", + "decoder.layers.*.self_attn.out_proj": "speecht5.decoder.wrapped_decoder.layers.*.self_attn.out_proj", + "decoder.layers.*.self_attn_layer_norm": "speecht5.decoder.wrapped_decoder.layers.*.self_attn_layer_norm", + "decoder.layers.*.encoder_attn.k_proj": "speecht5.decoder.wrapped_decoder.layers.*.encoder_attn.k_proj", + "decoder.layers.*.encoder_attn.v_proj": "speecht5.decoder.wrapped_decoder.layers.*.encoder_attn.v_proj", + "decoder.layers.*.encoder_attn.q_proj": "speecht5.decoder.wrapped_decoder.layers.*.encoder_attn.q_proj", + "decoder.layers.*.encoder_attn.out_proj": "speecht5.decoder.wrapped_decoder.layers.*.encoder_attn.out_proj", + "decoder.layers.*.encoder_attn_layer_norm": "speecht5.decoder.wrapped_decoder.layers.*.encoder_attn_layer_norm", + "decoder.layers.*.fc1": "speecht5.decoder.wrapped_decoder.layers.*.feed_forward.intermediate_dense", + "decoder.layers.*.fc2": "speecht5.decoder.wrapped_decoder.layers.*.feed_forward.output_dense", + "decoder.layers.*.final_layer_norm": "speecht5.decoder.wrapped_decoder.layers.*.final_layer_norm", +} +MAPPING_S2T = { + **MAPPING_SPEECH_ENCODER_PRENET, + **MAPPING_ENCODER, + **MAPPING_DECODER, + **MAPPING_TEXT_DECODER_PRENET, + **MAPPING_TEXT_DECODER_POSTNET, +} +MAPPING_T2S = { + **MAPPING_TEXT_ENCODER_PRENET, + **MAPPING_ENCODER, + **MAPPING_DECODER, + **MAPPING_SPEECH_DECODER_PRENET, + **MAPPING_SPEECH_DECODER_POSTNET, +} +MAPPING_S2S = { + **MAPPING_SPEECH_ENCODER_PRENET, + **MAPPING_ENCODER, + **MAPPING_DECODER, + **MAPPING_SPEECH_DECODER_PRENET, + **MAPPING_SPEECH_DECODER_POSTNET, +} +TOP_LEVEL_KEYS = [] +IGNORE_KEYS = [ + 
"encoder.version", + "encoder.layers.*.norm_k.weight", + "encoder.layers.*.norm_k.bias", + "decoder.version", + "decoder.layers.*.norm_k.weight", + "decoder.layers.*.norm_k.bias", + "decoder.pos_emb.pe_k", + "speech_encoder_prenet.embed_positions._float_tensor", + "text_decoder_prenet.embed_positions._float_tensor", +] +IGNORE_KEYS_S2T = IGNORE_KEYS + [ + "encoder.proj", + "text_encoder_prenet.*", + "speech_decoder_prenet.*", + "speech_decoder_postnet.*", +] +IGNORE_KEYS_T2S = IGNORE_KEYS + [ + "encoder.proj", + "speech_encoder_prenet.*", + "text_decoder_prenet.*", + "text_decoder_postnet.*", +] +IGNORE_KEYS_S2S = IGNORE_KEYS + [ + "encoder.proj", + "text_encoder_prenet.*", + "text_decoder_prenet.*", + "text_decoder_postnet.*", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + if hf_shape != value.shape: + raise ValueError( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + elif weight_type == "running_mean": + hf_pointer.running_mean.data = value + elif weight_type == "running_var": + hf_pointer.running_var.data = value + elif weight_type == "num_batches_tracked": + hf_pointer.num_batches_tracked.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.") + + +def should_ignore(name, ignore_keys): + for key in ignore_keys: + if key.endswith(".*"): + if name.startswith(key[:-1]): + return True + elif ".*." in key: + prefix, suffix = key.split(".*.") + if prefix in name and suffix in name: + return True + elif key in name: + return True + return False + + +def recursively_load_weights(fairseq_dict, hf_model, task): + unused_weights = [] + + if task == "s2t": + feature_encoder = hf_model.speecht5.encoder.prenet.feature_encoder + MAPPING = MAPPING_S2T + IGNORE_KEYS = IGNORE_KEYS_S2T + elif task == "t2s": + feature_encoder = None + MAPPING = MAPPING_T2S + IGNORE_KEYS = IGNORE_KEYS_T2S + elif task == "s2s": + feature_encoder = hf_model.speecht5.encoder.prenet.feature_encoder + MAPPING = MAPPING_S2S + IGNORE_KEYS = IGNORE_KEYS_S2S + else: + raise ValueError(f"Unsupported task: {task}") + + for name, value in fairseq_dict.items(): + if should_ignore(name, IGNORE_KEYS): + logger.info(f"{name} was ignored") + continue + + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_encoder, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + # mapped_key = "speecht5." 
+ mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
+
+                if "*" in key:
+                    prefix, suffix = key.split(".*.")
+                    if prefix in name and suffix in name:
+                        key = suffix
+
+                # if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
+                if key in name:
+                    is_used = True
+                    if "*" in mapped_key:
+                        layer_index = name.split(key)[0].split(".")[-2]
+                        mapped_key = mapped_key.replace("*", layer_index)
+                    if "weight_g" in name:
+                        weight_type = "weight_g"
+                    elif "weight_v" in name:
+                        weight_type = "weight_v"
+                    elif "bias" in name:
+                        weight_type = "bias"
+                    elif "weight" in name:
+                        weight_type = "weight"
+                    elif "running_mean" in name:
+                        weight_type = "running_mean"
+                    elif "running_var" in name:
+                        weight_type = "running_var"
+                    elif "num_batches_tracked" in name:
+                        weight_type = "num_batches_tracked"
+                    else:
+                        weight_type = None
+                    set_recursively(hf_model, mapped_key, value, name, weight_type)
+                continue
+        if not is_used:
+            unused_weights.append(name)
+
+    logger.warning(f"Unused weights: {unused_weights}")
+
+
+def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
+    name = full_name.split("conv_layers.")[-1]
+    items = name.split(".")
+    layer_id = int(items[0])
+    type_id = int(items[1])
+
+    if type_id == 0:
+        if "bias" in name:
+            if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
+                raise ValueError(
+                    f"{full_name} has size {value.shape}, but"
+                    f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+                )
+            feature_extractor.conv_layers[layer_id].conv.bias.data = value
+            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+        elif "weight" in name:
+            if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
+                raise ValueError(
+                    f"{full_name} has size {value.shape}, but"
+                    f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+                )
+            feature_extractor.conv_layers[layer_id].conv.weight.data = value
+            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+    elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
+        if "bias" in name:
+            if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
+                raise ValueError(
+                    f"{full_name} has size {value.shape}, but"
+                    f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+                )
+            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
+            logger.info(f"Feat extract layer norm bias of layer {layer_id} was initialized from {full_name}.")
+        elif "weight" in name:
+            if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
+                raise ValueError(
+                    f"{full_name} has size {value.shape}, but"
+                    f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+                )
+            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
+            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+    else:
+        unused_weights.append(full_name)
+
+
+@torch.no_grad()
+def convert_speecht5_checkpoint(
+    task,
+    checkpoint_path,
+    pytorch_dump_folder_path,
+    config_path=None,
+    vocab_path=None,
+    repo_id=None,
+):
+    """
+    Copy/paste/tweak model's weights to transformers design.
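+
+    A sketch of a typical invocation (the script name and paths here are hypothetical):
+
+        python convert_speecht5.py \
+            --task t2s \
+            --checkpoint_path ./speecht5_tts.pt \
+            --vocab_path ./spm_char.model \
+            --pytorch_dump_folder_path ./speecht5_tts_hf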
+ """ + if config_path is not None: + config = SpeechT5Config.from_pretrained(config_path) + else: + config = SpeechT5Config() + + if task == "s2t": + config.max_length = config.max_text_positions + model = SpeechT5ForSpeechToText(config) + elif task == "t2s": + config.max_speech_positions = 1876 + config.max_text_positions = 600 + config.max_length = config.max_speech_positions + model = SpeechT5ForTextToSpeech(config) + elif task == "s2s": + config.max_speech_positions = 1876 + config.max_length = config.max_speech_positions + model = SpeechT5ForSpeechToSpeech(config) + else: + raise ValueError(f"Unknown task name: {task}") + + if vocab_path: + tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions) + + # Mask token behaves like a normal word, i.e. include the space before it + mask_token = AddedToken("", lstrip=True, rstrip=False) + tokenizer.mask_token = mask_token + tokenizer.add_special_tokens({"mask_token": mask_token}) + tokenizer.add_tokens([""]) + + feature_extractor = SpeechT5FeatureExtractor() + processor = SpeechT5Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + processor.save_pretrained(pytorch_dump_folder_path) + + fairseq_checkpoint = torch.load(checkpoint_path, weights_only=True) + recursively_load_weights(fairseq_checkpoint["model"], model, task) + + model.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + processor.push_to_hub(repo_id) + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", + default="s2t", + type=str, + help="Type of the SpeechT5 model you'd like to convert. Should be one of 's2t', 't2s', 's2s'.", + ) + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--vocab_path", default=None, type=str, help="Path to SentencePiece model") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + + args = parser.parse_args() + convert_speecht5_checkpoint( + args.task, + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + args.vocab_path, + args.push_to_hub, + ) diff --git a/docs/transformers/src/transformers/models/speecht5/feature_extraction_speecht5.py b/docs/transformers/src/transformers/models/speecht5/feature_extraction_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b277644ce94d3a6b1dfaf2657bdf02b115f29f --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/feature_extraction_speecht5.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for SpeechT5.""" + +import warnings +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class SpeechT5FeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a SpeechT5 feature extractor. + + This class can pre-process a raw speech signal by (optionally) normalizing to zero-mean unit-variance, for use by + the SpeechT5 speech encoder prenet. + + This class can also extract log-mel filter bank features from raw speech, for use by the SpeechT5 speech decoder + prenet. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly + improve the performance for some models. + num_mel_bins (`int`, *optional*, defaults to 80): + The number of mel-frequency bins in the extracted spectrogram features. + hop_length (`int`, *optional*, defaults to 16): + Number of ms between windows. Otherwise referred to as "shift" in many papers. + win_length (`int`, *optional*, defaults to 64): + Number of ms per window. + win_function (`str`, *optional*, defaults to `"hann_window"`): + Name for the window function used for windowing, must be accessible via `torch.{win_function}` + frame_signal_scale (`float`, *optional*, defaults to 1.0): + Constant multiplied in creating the frames before applying DFT. This argument is deprecated. + fmin (`float`, *optional*, defaults to 80): + Minimum mel frequency in Hz. + fmax (`float`, *optional*, defaults to 7600): + Maximum mel frequency in Hz. + mel_floor (`float`, *optional*, defaults to 1e-10): + Minimum value of mel frequency banks. + reduction_factor (`int`, *optional*, defaults to 2): + Spectrogram length reduction factor. This argument is deprecated. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`. 
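+
+    Example (a minimal sketch, assuming 16 kHz mono input):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import SpeechT5FeatureExtractor
+
+    >>> feature_extractor = SpeechT5FeatureExtractor()
+    >>> waveform = np.zeros(16000, dtype=np.float32)  # one second of silence
+    >>> inputs = feature_extractor(audio=waveform, sampling_rate=16000, return_tensors="pt")
+    >>> list(inputs.keys())
+    ['input_values', 'attention_mask']
+    ```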
+ """ + + model_input_names = ["input_values", "attention_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 16000, + padding_value: float = 0.0, + do_normalize: bool = False, + num_mel_bins: int = 80, + hop_length: int = 16, + win_length: int = 64, + win_function: str = "hann_window", + frame_signal_scale: float = 1.0, + fmin: float = 80, + fmax: float = 7600, + mel_floor: float = 1e-10, + reduction_factor: int = 2, + return_attention_mask: bool = True, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.do_normalize = do_normalize + self.return_attention_mask = return_attention_mask + + self.num_mel_bins = num_mel_bins + self.hop_length = hop_length + self.win_length = win_length + self.win_function = win_function + self.frame_signal_scale = frame_signal_scale + self.fmin = fmin + self.fmax = fmax + self.mel_floor = mel_floor + self.reduction_factor = reduction_factor + + self.sample_size = win_length * sampling_rate // 1000 + self.sample_stride = hop_length * sampling_rate // 1000 + self.n_fft = optimal_fft_length(self.sample_size) + self.n_freqs = (self.n_fft // 2) + 1 + + self.window = window_function(window_length=self.sample_size, name=self.win_function, periodic=True) + + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_freqs, + num_mel_filters=self.num_mel_bins, + min_frequency=self.fmin, + max_frequency=self.fmax, + sampling_rate=self.sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + if frame_signal_scale != 1.0: + warnings.warn( + "The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + if reduction_factor != 2.0: + warnings.warn( + "The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers", + FutureWarning, + ) + + @staticmethod + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0 + ) -> List[np.ndarray]: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _extract_mel_features( + self, + one_waveform: np.ndarray, + ) -> np.ndarray: + """ + Extracts log-mel filterbank features for one waveform array (unbatched). 
+ """ + log_mel_spec = spectrogram( + one_waveform, + window=self.window, + frame_length=self.sample_size, + hop_length=self.sample_stride, + fft_length=self.n_fft, + mel_filters=self.mel_filters, + mel_floor=self.mel_floor, + log_mel="log10", + ) + return log_mel_spec.T + + def __call__( + self, + audio: Optional[Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]]] = None, + audio_target: Optional[Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]]] = None, + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Pass in a value for `audio` to extract waveform features. Pass in a value for `audio_target` to extract log-mel + spectrogram features. + + Args: + audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must + be mono channel audio, not stereo, i.e. single float per timestep. + audio_target (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*): + The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a + list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel + spectrogram features. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. 
+ sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` or `audio_target` input was sampled. It is strongly recommended + to pass `sampling_rate` at the forward call to prevent silent errors. + """ + if audio is None and audio_target is None: + raise ValueError("You must provide either `audio` or `audio_target` values.") + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if audio is not None: + inputs = self._process_audio( + audio, + False, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_attention_mask, + return_tensors, + **kwargs, + ) + else: + inputs = None + + if audio_target is not None: + inputs_target = self._process_audio( + audio_target, + True, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_attention_mask, + return_tensors, + **kwargs, + ) + + if inputs is None: + return inputs_target + else: + inputs["labels"] = inputs_target["input_values"] + decoder_attention_mask = inputs_target.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs + + def _process_audio( + self, + speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + is_target: bool = False, + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1 + if is_batched_numpy and len(speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + speech = [np.asarray(speech, dtype=np.float32) for speech in speech] + elif not is_batched and not isinstance(speech, np.ndarray): + speech = np.asarray(speech, dtype=np.float32) + elif isinstance(speech, np.ndarray) and speech.dtype is np.dtype(np.float64): + speech = speech.astype(np.float32) + + # always return batch + if not is_batched: + speech = [speech] + + # needed to make pad() work on spectrogram inputs + feature_size_hack = self.feature_size + + # convert into correct format for padding + if is_target: + features = [self._extract_mel_features(waveform) for waveform in speech] + encoded_inputs = BatchFeature({"input_values": features}) + self.feature_size = self.num_mel_bins + else: + encoded_inputs = BatchFeature({"input_values": speech}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + self.feature_size = feature_size_hack + + # convert input values to correct format + input_values = padded_inputs["input_values"] + if not 
isinstance(input_values[0], np.ndarray): + padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values] + elif ( + not isinstance(input_values, np.ndarray) + and isinstance(input_values[0], np.ndarray) + and input_values[0].dtype is np.dtype(np.float64) + ): + padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values] + elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64): + padded_inputs["input_values"] = input_values.astype(np.float32) + + # convert attention_mask to correct format + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + # zero-mean and unit-variance normalization + if not is_target and self.do_normalize: + attention_mask = ( + attention_mask + if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD + else None + ) + padded_inputs["input_values"] = self.zero_mean_unit_var_norm( + padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value + ) + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + def to_dict(self) -> Dict[str, Any]: + output = super().to_dict() + + # Don't serialize these as they are derived from the other properties. + names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"] + for name in names: + if name in output: + del output[name] + + return output + + +__all__ = ["SpeechT5FeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/speecht5/modeling_speecht5.py b/docs/transformers/src/transformers/models/speecht5/modeling_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..a46ab875d425a07fac7efb6992db1d29eec952b3 --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/modeling_speecht5.py @@ -0,0 +1,3364 @@ +# coding=utf-8 +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SpeechT5 model.""" + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqSpectrogramOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 1 + +# General docstring +_CONFIG_FOR_DOC = "SpeechT5Config" + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def shift_spectrograms_right( + input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: Optional[torch.Tensor] = None +): + """ + Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length. + """ + # thin out frames for reduction factor + if reduction_factor > 1: + input_values = input_values[:, reduction_factor - 1 :: reduction_factor] + if attention_mask is not None: + attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor] + + shifted_input_values = input_values.new_zeros(input_values.shape) + shifted_input_values[:, 1:] = input_values[:, :-1].clone() + + # replace possible -100 values in labels by zeros + shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0) + + return shifted_input_values, attention_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. 
The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
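+        # e.g. if max_num_masked_span == 3 but only one start index [5] was sampled,
+        # the padded vector becomes [5, 5, 5]; the duplicated spans fully overlap,
+        # so they mask nothing beyond the original span.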
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5
+class SpeechT5NoLayerNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5
+class SpeechT5LayerNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+
+        hidden_states = hidden_states.transpose(-2, -1)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(-2, -1)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5
+class SpeechT5GroupNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv =
nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->SpeechT5 +class SpeechT5SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.weights = nn.Parameter(emb_weights) + self.weights.requires_grad = False + self.weights.detach_() + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. 
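+        For example, with `padding_idx=1` and an empty cache, `[[5, 6, 1, 1]]` maps to
+        positions `[[2, 3, 1, 1]]`: real tokens count up from `padding_idx + 1` while
+        padding positions stay at `padding_idx`.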
+ + Args: + x: torch.Tensor x: + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SpeechT5 +class SpeechT5PositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class SpeechT5ScaledPositionalEncoding(nn.Module): + """ + Scaled positional encoding, see §3.2 in https://arxiv.org/abs/1809.08895 + """ + + def __init__(self, dropout, dim, max_len=5000): + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) + super().__init__() + self.register_buffer("pe", pe, persistent=False) + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + + def forward(self, emb): + emb = emb + self.alpha * self.pe[:, : emb.size(1)] + emb = self.dropout(emb) + return emb + + +class SpeechT5RelativePositionalEncoding(torch.nn.Module): + def __init__(self, dim, max_length=1000): + super().__init__() + self.dim = dim + self.max_length = max_length + self.pe_k = torch.nn.Embedding(2 * max_length, dim) + + def forward(self, hidden_states): + seq_len = hidden_states.shape[1] + pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long) + pos_seq = pos_seq[:, None] - pos_seq[None, :] + + pos_seq[pos_seq < -self.max_length] = -self.max_length + pos_seq[pos_seq >= self.max_length] = self.max_length - 1 + pos_seq = pos_seq + self.max_length + + return self.pe_k(pos_seq) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SpeechT5 +class 
SpeechT5SamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SpeechT5 +class SpeechT5FeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [ + SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + conv_layer.__call__, + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->SpeechT5 +class SpeechT5FeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class SpeechT5SpeechEncoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.feature_encoder = SpeechT5FeatureEncoder(config) + self.feature_projection = SpeechT5FeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.pos_conv_embed = SpeechT5PositionalConvEmbedding(config) + self.pos_sinusoidal_embed = SpeechT5SinusoidalPositionalEmbedding( + config.max_speech_positions + config.pad_token_id + 1, + config.hidden_size, + config.pad_token_id, + ) + + def freeze_feature_encoder(self): + self.feature_encoder._freeze_parameters() + + def forward( + self, + input_values: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + ): + extract_features = 
self.feature_encoder(input_values)
+        extract_features = extract_features.transpose(1, 2)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1],
+                attention_mask,
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        positional_conv_embedding = self.pos_conv_embed(hidden_states)
+        hidden_states = hidden_states + positional_conv_embedding
+
+        if attention_mask is not None:
+            padding_mask = attention_mask.ne(1).long()
+        else:
+            padding_mask = torch.zeros(hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device)
+
+        positional_sinusoidal_embeddings = self.pos_sinusoidal_embed(padding_mask)
+        hidden_states = hidden_states + positional_sinusoidal_embeddings
+
+        return hidden_states, attention_mask
+
+    # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feature_vector_attention_mask
+    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+        # Effectively attention_mask.sum(-1), but not inplace to be able to run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output length indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
+    # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feat_extract_output_lengths
+    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        return input_lengths
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+    def _mask_hidden_states(
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://arxiv.org/abs/1904.08779).
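+
+        Time masking replaces the selected frames with the learned `masked_spec_embed` vector, while feature
+        masking zeroes the selected channels; both probabilistic branches are only active during training.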
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + +class SpeechT5SpeechDecoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList( + [ + nn.Linear( + config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units, + config.speech_decoder_prenet_units, + ) + for i in range(config.speech_decoder_prenet_layers) + ] + ) + + self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size) + self.encode_positions = SpeechT5ScaledPositionalEncoding( + config.positional_dropout, + config.hidden_size, + config.max_speech_positions, + ) + self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size) + + def _consistent_dropout(self, inputs_embeds, p): + mask = torch.bernoulli(inputs_embeds[0], p=p) + all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1) + return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p) + + def forward( + self, + input_values: torch.Tensor, + speaker_embeddings: Optional[torch.Tensor] = None, + ): + # Dropout is always applied, even when evaluating. See §2.2 in https://arxiv.org/abs/1712.05884. 
+ + inputs_embeds = input_values + for layer in self.layers: + inputs_embeds = nn.functional.relu(layer(inputs_embeds)) + inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout) + + inputs_embeds = self.final_layer(inputs_embeds) + inputs_embeds = self.encode_positions(inputs_embeds) + + if speaker_embeddings is not None: + speaker_embeddings = nn.functional.normalize(speaker_embeddings) + speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1) + inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1) + inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds)) + + return inputs_embeds + + +class SpeechT5BatchNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + + if layer_id == 0: + in_conv_dim = config.num_mel_bins + else: + in_conv_dim = config.speech_decoder_postnet_units + + if layer_id == config.speech_decoder_postnet_layers - 1: + out_conv_dim = config.num_mel_bins + else: + out_conv_dim = config.speech_decoder_postnet_units + + self.conv = nn.Conv1d( + in_conv_dim, + out_conv_dim, + kernel_size=config.speech_decoder_postnet_kernel, + stride=1, + padding=(config.speech_decoder_postnet_kernel - 1) // 2, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(out_conv_dim) + + if layer_id < config.speech_decoder_postnet_layers - 1: + self.activation = nn.Tanh() + else: + self.activation = None + + self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + if self.activation is not None: + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SpeechT5SpeechDecoderPostnet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor) + self.prob_out = nn.Linear(config.hidden_size, config.reduction_factor) + + self.layers = nn.ModuleList( + [SpeechT5BatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)] + ) + + def forward(self, hidden_states: torch.Tensor): + outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins) + outputs_after_postnet = self.postnet(outputs_before_postnet) + logits = self.prob_out(hidden_states).view(hidden_states.size(0), -1) + return outputs_before_postnet, outputs_after_postnet, logits + + def postnet(self, hidden_states: torch.Tensor): + layer_output = hidden_states.transpose(1, 2) + for layer in self.layers: + layer_output = layer(layer_output) + return hidden_states + layer_output.transpose(1, 2) + + +class SpeechT5TextEncoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.encode_positions = SpeechT5ScaledPositionalEncoding( + config.positional_dropout, + config.hidden_size, + config.max_text_positions, + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward(self, input_ids: torch.Tensor): + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.encode_positions(inputs_embeds) + return inputs_embeds + + +class SpeechT5TextDecoderPrenet(nn.Module): + def 
__init__(self, config): + super().__init__() + self.config = config + self.dropout = nn.Dropout(config.positional_dropout) + self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.embed_positions = SpeechT5SinusoidalPositionalEmbedding( + config.max_text_positions + config.pad_token_id + 1, + config.hidden_size, + config.pad_token_id, + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + ): + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + else: + raise ValueError("You have to specify `decoder_input_ids`") + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + positions = self.embed_positions(input_ids, past_key_values_length) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds += positions + inputs_embeds = self.dropout(inputs_embeds) + + return inputs_embeds, attention_mask + + +class SpeechT5TextDecoderPostnet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, hidden_states: torch.Tensor): + return self.lm_head(hidden_states) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + +class SpeechT5Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper with relative position bias (see + https://aclanthology.org/N18-2074.pdf) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # relative attention bias
+        if position_bias is not None:
+            reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
+            rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
+            rel_pos_bias = rel_pos_bias.transpose(0, 1).view(
+                bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
+            )
+            attn_weights += rel_pos_bias
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class SpeechT5FeedForward(nn.Module): + def __init__(self, config, intermediate_size): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SpeechT5EncoderLayer(nn.Module): + def __init__(self, config: SpeechT5Config): + super().__init__() + self.attention = SpeechT5Attention( + embed_dim=config.hidden_size, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = SpeechT5FeedForward(config, config.encoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + position_bias (`torch.FloatTensor`): + relative position embeddings of size `(seq_len, seq_len, hidden_size // encoder_attention_heads)` + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
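+
+        Example (a minimal shape sketch, not part of the original file; it assumes the default `SpeechT5Config`
+        with `hidden_size=768`, and `SpeechT5EncoderLayer` is internal to this module, so the snippet only runs
+        from within this file's scope):
+
+        ```python
+        >>> import torch
+        >>> from transformers import SpeechT5Config
+
+        >>> layer = SpeechT5EncoderLayer(SpeechT5Config())
+        >>> hidden = torch.randn(2, 10, 768)
+        >>> layer(hidden)[0].shape
+        torch.Size([2, 10, 768])
+        ```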
+        """
+        residual = hidden_states
+        hidden_states, attn_weights, _ = self.attention(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            position_bias=position_bias,
+            output_attentions=output_attentions,
+        )
+
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = residual + hidden_states
+
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states + self.feed_forward(hidden_states)
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class SpeechT5DecoderLayer(nn.Module):
+    def __init__(self, config: SpeechT5Config):
+        super().__init__()
+        self.self_attn = SpeechT5Attention(
+            embed_dim=config.hidden_size,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.encoder_attn = SpeechT5Attention(
+            config.hidden_size,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.feed_forward = SpeechT5FeedForward(config, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(decoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
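+
+        Example (a minimal shape sketch, not part of the original file; it assumes the default `SpeechT5Config`
+        with `hidden_size=768`, and `SpeechT5DecoderLayer` is internal to this module. Note the causal mask is
+        prepared by [`SpeechT5Decoder`], not by this layer):
+
+        ```python
+        >>> import torch
+        >>> from transformers import SpeechT5Config
+
+        >>> layer = SpeechT5DecoderLayer(SpeechT5Config())
+        >>> tgt = torch.randn(2, 5, 768)
+        >>> memory = torch.randn(2, 12, 768)
+        >>> layer(tgt, encoder_hidden_states=memory)[0].shape
+        torch.Size([2, 5, 768])
+        ```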
+ """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class SpeechT5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
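+
+    Concrete subclasses are typically instantiated through `from_pretrained`. A short illustrative sketch
+    (the checkpoint name is the publicly released TTS checkpoint; downloading it requires network access):
+
+    ```python
+    >>> from transformers import SpeechT5ForTextToSpeech
+
+    >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")  # doctest: +SKIP
+    ```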
+ """ + + config_class = SpeechT5Config + base_model_prefix = "speecht5" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SpeechT5PositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, SpeechT5FeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class SpeechT5Encoder(SpeechT5PreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layerdrop = config.encoder_layerdrop + + self.layers = nn.ModuleList([SpeechT5EncoderLayer(config) for _ in range(config.encoder_layers)]) + + self.embed_positions = SpeechT5RelativePositionalEncoding( + config.hidden_size // config.encoder_attention_heads, config.encoder_max_relative_position + ) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): + Features extracted from the speech or text input by the encoder prenet. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + position_bias = self.embed_positions(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + skip_the_layer = False + if self.training: + dropout_probability = torch.rand([]) + skip_the_layer = dropout_probability < self.layerdrop + + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + position_bias, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Encoder that applies SpeechT5SpeechEncoderPrenet to convert the audio waveform data to + hidden features. 
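+
+    Example (an illustrative sketch with randomly initialized weights, not part of the original file; the
+    prenet's convolutional feature encoder downsamples the time axis, so the output sequence is much shorter
+    than the input waveform):
+
+    ```python
+    >>> import torch
+    >>> from transformers import SpeechT5Config
+
+    >>> encoder = SpeechT5EncoderWithSpeechPrenet(SpeechT5Config())
+    >>> waveform = torch.randn(1, 16000)  # roughly one second of 16 kHz audio
+    >>> outputs = encoder(input_values=waveform)
+    >>> outputs.last_hidden_state.shape[-1]  # hidden_size
+    768
+    ```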
+ """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5SpeechEncoderPrenet(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + hidden_states, attention_mask = self.prenet(input_values, attention_mask) + + outputs = self.wrapped_encoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Encoder that applies SpeechT5TextEncoderPrenet to convert the input_ids to hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5TextEncoderPrenet(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prenet.get_input_embeddings() + + def set_input_embeddings(self, value): + self.prenet.set_input_embeddings(value) + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + hidden_states = self.prenet(input_values) + + outputs = self.wrapped_encoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with + [`SpeechT5Model`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + return self.wrapped_encoder( + hidden_states=input_values, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SpeechT5Decoder(SpeechT5PreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`SpeechT5DecoderLayer`] + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([SpeechT5DecoderLayer(config) for _ in range(config.decoder_layers)]) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): + Features extracted from the speech or text input by the decoder prenet. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        input_shape = hidden_states.size()[:-1]
+
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, input_shape, hidden_states, past_key_values_length
+        )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
+            )
+
+        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+ ) + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + skip_the_layer = False + if self.training: + dropout_probability = torch.rand([]) + skip_the_layer = dropout_probability < self.layerdrop + if skip_the_layer and not synced_gpus: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Decoder that applies SpeechT5SpeechDecoderPrenet to convert log-mel filterbanks to hidden + features. 
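+
+    Example (an illustrative sketch with randomly initialized weights, not part of the original file; it
+    assumes the default config values `num_mel_bins=80`, `speaker_embedding_dim=512` and `hidden_size=768`):
+
+    ```python
+    >>> import torch
+    >>> from transformers import SpeechT5Config
+
+    >>> config = SpeechT5Config()
+    >>> decoder = SpeechT5DecoderWithSpeechPrenet(config)
+    >>> mel = torch.randn(1, 20, config.num_mel_bins)
+    >>> speaker = torch.randn(1, config.speaker_embedding_dim)
+    >>> memory = torch.randn(1, 30, config.hidden_size)
+    >>> out = decoder(input_values=mel, speaker_embeddings=speaker, encoder_hidden_states=memory)
+    >>> out.last_hidden_state.shape
+    torch.Size([1, 20, 768])
+    ```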
+ """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5SpeechDecoderPrenet(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeddings: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + decoder_hidden_states = self.prenet(input_values, speaker_embeddings) + + outputs = self.wrapped_decoder( + hidden_states=decoder_hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Decoder that applies SpeechT5TextDecoderPrenet to convert input tokens to hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5TextDecoderPrenet(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prenet.get_input_embeddings() + + def set_input_embeddings(self, value): + self.prenet.set_input_embeddings(value) + + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values) + + outputs = self.wrapped_decoder( + hidden_states=decoder_hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with + [`SpeechT5Model`]. 
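+
+    Example (an illustrative sketch, not part of the original file; with no prenet, `input_values` are already
+    hidden states):
+
+    ```python
+    >>> import torch
+    >>> from transformers import SpeechT5Config
+
+    >>> config = SpeechT5Config()
+    >>> decoder = SpeechT5DecoderWithoutPrenet(config)
+    >>> tgt = torch.randn(1, 6, config.hidden_size)
+    >>> memory = torch.randn(1, 30, config.hidden_size)
+    >>> decoder(input_values=tgt, encoder_hidden_states=memory).last_hidden_state.shape
+    torch.Size([1, 6, 768])
+    ```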
+ """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + outputs = self.wrapped_decoder( + hidden_states=input_values, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs + + +class SpeechT5GuidedMultiheadAttentionLoss(nn.Module): + """ + Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention](https://arxiv.org/abs/1710.08969), adapted for multi-head attention. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.sigma = config.guided_attention_loss_sigma + self.scale = config.guided_attention_loss_scale + + def forward( + self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor + ) -> torch.Tensor: + """ + Compute the attention loss. + + Args: + attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`): + Batch of multi-head attention weights + input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`): + Input attention mask as booleans. + output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`): + Target attention mask as booleans. 
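+
+        Example (an illustrative sketch, not part of the original file; random attention weights and no
+        padding, so the masks are all ones):
+
+        ```python
+        >>> import torch
+        >>> from transformers import SpeechT5Config
+
+        >>> criterion = SpeechT5GuidedMultiheadAttentionLoss(SpeechT5Config())
+        >>> attentions = torch.rand(1, 2, 4, 6)  # (batch, layers * heads, output_len, input_len)
+        >>> input_masks = torch.ones(1, 6, dtype=torch.bool)
+        >>> output_masks = torch.ones(1, 4, dtype=torch.bool)
+        >>> loss = criterion(attentions, input_masks, output_masks)
+        ```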
+ + Returns: + `torch.Tensor` with the loss value + """ + guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device) + masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2) + masks = masks.to(attentions.device).unsqueeze(1) + + losses = guided_attn_masks * attentions + loss = torch.mean(losses.masked_select(masks)) + return self.scale * loss + + def _make_guided_attention_masks(self, input_masks, output_masks, device): + input_lengths = input_masks.sum(-1) + output_lengths = output_masks.sum(-1) + + guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device) + + for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device) + + return guided_attn_masks.unsqueeze(1) + + @staticmethod + def _make_guided_attention_mask(input_length, output_length, sigma, device): + grid_y, grid_x = torch.meshgrid( + torch.arange(input_length, device=device), + torch.arange(output_length, device=device), + indexing="xy", + ) + grid_x = grid_x.float() / output_length + grid_y = grid_y.float() / input_length + return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2))) + + +class SpeechT5SpectrogramLoss(nn.Module): + """ + Loss computation used by SpeechT5ForTextToSpeech. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.use_guided_attention_loss = config.use_guided_attention_loss + self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads + self.reduction_factor = config.reduction_factor + + self.l1_criterion = L1Loss() + self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0)) + + if self.use_guided_attention_loss: + self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config) + + def forward( + self, + attention_mask: torch.LongTensor, + outputs_before_postnet: torch.FloatTensor, + outputs_after_postnet: torch.FloatTensor, + logits: torch.FloatTensor, + labels: torch.FloatTensor, + cross_attentions: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + padding_mask = labels != -100.0 + + # mask out the padded portions + labels = labels.masked_select(padding_mask) + outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask) + outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask) + + # spectrogram loss + l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels) + + # construct stop labels from the padding mask + masks = padding_mask[:, :, 0] + stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1) + stop_labels = stop_labels[:, 1:].masked_select(masks) + logits = logits.masked_select(masks) + + # stop token loss + bce_loss = self.bce_criterion(logits, stop_labels) + + # combined loss + loss = l1_loss + bce_loss + + # guided attention loss + if self.use_guided_attention_loss: + attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1) + input_masks = attention_mask == 1 + output_masks = padding_mask[:, :, 0] + if self.reduction_factor > 1: + output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor] + attn_loss = self.attn_criterion(attn, input_masks, output_masks) + loss += attn_loss + + return loss + + +SPEECHT5_BASE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SpeechT5Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+        encoder ([`SpeechT5EncoderWithSpeechPrenet`] or [`SpeechT5EncoderWithTextPrenet`] or `None`):
+            The Transformer encoder module that applies the appropriate speech or text encoder prenet. If `None`,
+            [`SpeechT5EncoderWithoutPrenet`] will be used and the `input_values` are assumed to be hidden states.
+        decoder ([`SpeechT5DecoderWithSpeechPrenet`] or [`SpeechT5DecoderWithTextPrenet`] or `None`):
+            The Transformer decoder module that applies the appropriate speech or text decoder prenet. If `None`,
+            [`SpeechT5DecoderWithoutPrenet`] will be used and the `decoder_input_values` are assumed to be hidden
+            states.
+"""
+
+
+SPEECHT5_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SpeechT5Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+SPEECHT5_INPUTS_DOCSTRING = r"""
+    Args:
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
+            `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            <Tip>
+
+            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+            True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
+            **not** be passed to avoid degraded performance when doing batched inference. For such models
+            `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
+            models also yield slightly different results depending on whether `input_values` is padded or not.
+
+            </Tip>
+
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
+            also be used by default.
+
+            If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
+            and modify to your needs.
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+        head_mask (`torch.FloatTensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.FloatTensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_values` (those
+            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_values` of shape `(batch_size, sequence_length)`.
+
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_values` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_values` indices into associated vectors than the model's internal embedding lookup matrix.
+
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+""" + + +@add_start_docstrings( + "The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.", + SPEECHT5_BASE_START_DOCSTRING, +) +class SpeechT5Model(SpeechT5PreTrainedModel): + def __init__( + self, + config: SpeechT5Config, + encoder: Optional[nn.Module] = None, + decoder: Optional[nn.Module] = None, + ): + super().__init__(config) + self.config = config + self.encoder = SpeechT5EncoderWithoutPrenet(config) if encoder is None else encoder + self.decoder = SpeechT5DecoderWithoutPrenet(config) if decoder is None else decoder + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet): + return self.encoder.get_input_embeddings() + if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): + return self.decoder.get_input_embeddings() + raise NotImplementedError + + def set_input_embeddings(self, value): + if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet): + self.encoder.set_input_embeddings(value) + if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): + self.decoder.set_input_embeddings(value) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + if isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet): + self.encoder.prenet.freeze_feature_encoder() + + @add_start_docstrings_to_model_forward(SPEECHT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Depending on which encoder is being used, the `input_values` are either: float values of the input raw + speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states. + + decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel + filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in + the vocabulary, or hidden states. + + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. 
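+
+        Example (an illustrative sketch, not part of the original file; a bare `SpeechT5Model` uses the
+        prenet-less wrappers, so both `input_values` and `decoder_input_values` are plain hidden states):
+
+        ```python
+        >>> import torch
+        >>> from transformers import SpeechT5Config, SpeechT5Model
+
+        >>> config = SpeechT5Config()
+        >>> model = SpeechT5Model(config)
+        >>> src = torch.randn(1, 50, config.hidden_size)
+        >>> tgt = torch.randn(1, 10, config.hidden_size)
+        >>> model(input_values=src, decoder_input_values=tgt).last_hidden_state.shape
+        torch.Size([1, 10, 768])
+        ```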
+ + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_values=input_values, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # downsample encoder attention mask (only for encoders with speech input) + if attention_mask is not None and isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet): + encoder_attention_mask = self.encoder.prenet._get_feature_vector_attention_mask( + encoder_outputs[0].shape[1], attention_mask + ) + else: + encoder_attention_mask = attention_mask + + if isinstance(self.decoder, SpeechT5DecoderWithSpeechPrenet): + decoder_args = {"speaker_embeddings": speaker_embeddings} + else: + decoder_args = {} + + decoder_outputs = self.decoder( + input_values=decoder_input_values, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **decoder_args, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """SpeechT5 Model with a speech encoder and a text decoder.""", + SPEECHT5_START_DOCSTRING, +) +class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"] + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that does not define the" + " vocabulary size of the language model head. Please instantiate the model as follows:" + " `SpeechT5ForSpeechToText.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of" + " your model's configuration." 
+ ) + + speech_encoder = SpeechT5EncoderWithSpeechPrenet(config) + text_decoder = SpeechT5DecoderWithTextPrenet(config) + self.speecht5 = SpeechT5Model(config, speech_encoder, text_decoder) + + self.text_decoder_postnet = SpeechT5TextDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speecht5.get_encoder() + + def get_decoder(self): + return self.speecht5.get_decoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.get_encoder().prenet.freeze_feature_encoder() + + def get_output_embeddings(self): + return self.text_decoder_postnet.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.text_decoder_postnet.set_output_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(SPEECHT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install + soundfile*). To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + SpeechT5 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Label indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + Returns: + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToText + >>> from datasets import load_dataset + + >>> dataset = load_dataset( + ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ... ) # doctest: +IGNORE_RESULT + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr") + >>> model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr") + + >>> # audio file is decoded on the fly + >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> predicted_ids = model.generate(**inputs, max_length=100) + + >>> # transcribe speech + >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) + >>> transcription[0] + 'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel' + ``` + + ```python + >>> inputs["labels"] = processor(text_target=dataset[0]["text"], return_tensors="pt").input_ids + + >>> # compute loss + >>> loss = model(**inputs).loss + >>> round(loss.item(), 2) + 19.68 + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.speecht5( + input_values=input_values, + attention_mask=attention_mask, + decoder_input_values=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + logits = self.text_decoder_postnet(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +def _generate_speech( + model: SpeechT5PreTrainedModel, + input_values: torch.FloatTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, +) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: + if speaker_embeddings is None: + raise 
ValueError(
+            """`speaker_embeddings` must be specified. For example, you can create speaker embeddings by following
+            the code snippet provided at this link:
+            https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
+            """
+        )
+
+    if attention_mask is None:
+        encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
+    else:
+        encoder_attention_mask = attention_mask
+
+    bsz = input_values.size(0)
+
+    encoder_out = model.speecht5.encoder(
+        input_values=input_values,
+        attention_mask=encoder_attention_mask,
+        return_dict=True,
+    )
+
+    encoder_last_hidden_state = encoder_out.last_hidden_state
+
+    # downsample encoder attention mask
+    if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
+        encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
+            encoder_out[0].shape[1], encoder_attention_mask
+        )
+
+    maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor)
+    minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)
+
+    # Start the output sequence with a mel spectrum that is all zeros.
+    output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)
+
+    spectrogram = []
+    cross_attentions = []
+    past_key_values = None
+    idx = 0
+    result_spectrogram = {}
+
+    while True:
+        idx += 1
+
+        # Run the decoder prenet on the entire output sequence.
+        decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
+        # Run the decoder layers on the last element of the prenet output.
+        decoder_out = model.speecht5.decoder.wrapped_decoder(
+            hidden_states=decoder_hidden_states[:, -1:],
+            attention_mask=None,
+            encoder_hidden_states=encoder_last_hidden_state,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+            output_attentions=output_cross_attentions,
+            return_dict=True,
+        )
+
+        if output_cross_attentions:
+            cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
+
+        last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
+        past_key_values = decoder_out.past_key_values
+
+        # Predict the new mel spectrum for this step in the sequence.
+        spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
+        spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
+        spectrogram.append(spectrum)
+
+        # Extend the output sequence with the new mel spectrum.
+        new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
+        output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1)
+        # Predict the probability that this is the stop token.
+        prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
+
+        if idx < minlen:
+            continue
+        else:
+            # If the generation loop has not yet reached the maximum length, check which elements in the batch have
+            # met the stop-probability threshold. Otherwise, assume all have met it and fill in the remaining
+            # spectrograms for the batch.
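+            # `prob` holds one stop probability per newly generated frame (`reduction_factor` of them per batch
+            # element), so summing over the last axis before comparing against `threshold` lets any combination
+            # of frame-level stop signals end generation for that batch element.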
+ if idx < maxlen: + meet_thresholds = torch.sum(prob, dim=-1) >= threshold + meet_indexes = torch.where(meet_thresholds)[0].tolist() + else: + meet_indexes = range(len(prob)) + meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] + if len(meet_indexes) > 0: + spectrograms = torch.stack(spectrogram) + spectrograms = spectrograms.transpose(0, 1).flatten(1, 2) + spectrograms = model.speech_decoder_postnet.postnet(spectrograms) + for meet_index in meet_indexes: + result_spectrogram[meet_index] = spectrograms[meet_index] + if len(result_spectrogram) >= bsz: + break + spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] + if not return_output_lengths: + spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + if vocoder is not None: + outputs = vocoder(spectrogram) + else: + outputs = spectrogram + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + if bsz > 1: + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (outputs, cross_attentions) + else: + # batched return values should also include the spectrogram/waveform lengths + spectrogram_lengths = [] + for i in range(bsz): + spectrogram_lengths.append(spectrograms[i].size(0)) + if vocoder is None: + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + outputs = (spectrograms, spectrogram_lengths) + else: + waveforms = [] + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + waveforms = vocoder(spectrograms) + waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + outputs = (waveforms, waveform_lengths) + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (*outputs, cross_attentions) + return outputs + + +@add_start_docstrings( + """SpeechT5 Model with a text encoder and a speech decoder.""", + SPEECHT5_START_DOCSTRING, +) +class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): + main_input_name = "input_ids" + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that does not define the" + " vocabulary size of the language model head. Please instantiate the model as follows:" + " `SpeechT5ForTextToSpeech.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of" + " your model's configuration." + ) + + text_encoder = SpeechT5EncoderWithTextPrenet(config) + speech_decoder = SpeechT5DecoderWithSpeechPrenet(config) + self.speecht5 = SpeechT5Model(config, text_encoder, speech_decoder) + + self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def can_generate(cls) -> bool: + # Speecht5 has a unique model structure, where the external class (`SpeechT5ForTextToSpeech`) doesn't need to inherit from + # `GenerationMixin` (it has a non-standard generation method). This means that the base `can_generate()` will return `False`, + # but we need to override it so as to do `GenerationConfig` handling in multiple parts of the codebase. 
+ return True + + def get_encoder(self): + return self.speecht5.get_encoder() + + def get_decoder(self): + return self.speecht5.get_decoder() + + @add_start_docstrings_to_model_forward(SPEECHT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + stop_labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`): + Float values of input mel spectrogram. + + SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If + `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see + `past_key_values`). + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): + Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss + computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`] + for details. 
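
Since `labels` is what switches on the spectrogram loss, a minimal training-style sketch may help (illustrative only: it reuses the public `microsoft/speecht5_tts` checkpoint from the example below and substitutes a random tensor for real target features obtained via `processor(audio_target=...)`):

```python
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
# Stand-in for real targets from processor(audio_target=...), of shape
# (batch_size, num_frames, config.num_mel_bins).
labels = torch.randn(1, 100, model.config.num_mel_bins)

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    speaker_embeddings=torch.zeros((1, 512)),
    labels=labels,
)
print(outputs.loss)  # spectrogram L1 + stop-token BCE (+ guided attention, if enabled)
```

Note that `decoder_input_values` is derived from `labels` via `shift_spectrograms_right` in the forward pass below, so only `labels` needs to be supplied for training.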
+ + Returns: + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, set_seed + >>> import torch + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") + >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") + >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + + >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt") + >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file + + >>> set_seed(555) # make deterministic + + >>> # generate speech + >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder) + >>> speech.shape + torch.Size([15872]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_values is None: + decoder_input_values, decoder_attention_mask = shift_spectrograms_right( + labels, self.config.reduction_factor, decoder_attention_mask + ) + if self.config.use_guided_attention_loss: + output_attentions = True + + outputs = self.speecht5( + input_values=input_ids, + attention_mask=attention_mask, + decoder_input_values=decoder_input_values, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + speaker_embeddings=speaker_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0]) + + loss = None + if labels is not None: + criterion = SpeechT5SpectrogramLoss(self.config) + loss = criterion( + attention_mask, + outputs_before_postnet, + outputs_after_postnet, + logits, + labels, + outputs.cross_attentions, + ) + + if not return_dict: + output = (outputs_after_postnet,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSpectrogramOutput( + loss=loss, + spectrogram=outputs_after_postnet, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + **kwargs, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Attention mask from the tokenizer, required for batched inference to signal to the model where to + ignore padded tokens from the input_ids. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is not None: + batch_size = input_ids.size(0) + if speaker_embeddings.size(0) != batch_size: + if speaker_embeddings.size(0) == 1: + speaker_embeddings = speaker_embeddings.repeat(batch_size, 1) + else: + raise ValueError( + "The first dimension of speaker_embeddings must be either 1 or the same as batch_size." 
+ ) + + return _generate_speech( + self, + input_ids, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + @torch.no_grad() + def generate_speech( + self, + input_ids: torch.LongTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*, defaults to `None`): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. 
+ - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is not None: + batch_size = input_ids.size(0) + if speaker_embeddings.size(0) != batch_size: + if speaker_embeddings.size(0) == 1: + speaker_embeddings = speaker_embeddings.repeat(batch_size, 1) + else: + raise ValueError( + "The first dimension of speaker_embeddings must be either 1 or the same as batch size." + ) + + return _generate_speech( + self, + input_ids, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + +@add_start_docstrings( + """SpeechT5 Model with a speech encoder and a speech decoder.""", + SPEECHT5_START_DOCSTRING, +) +class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + speech_encoder = SpeechT5EncoderWithSpeechPrenet(config) + speech_decoder = SpeechT5DecoderWithSpeechPrenet(config) + self.speecht5 = SpeechT5Model(config, speech_encoder, speech_decoder) + + self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speecht5.get_encoder() + + def get_decoder(self): + return self.speecht5.get_decoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. 
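
For instance, in a hypothetical fine-tuning setup (a sketch assuming the public `microsoft/speecht5_vc` checkpoint):

```python
from transformers import SpeechT5ForSpeechToSpeech

model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
model.freeze_feature_encoder()  # the convolutional feature encoder now has requires_grad=False

frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print(len(frozen))  # the speech prenet's feature-encoder parameters appear here
```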
+ """ + self.get_encoder().prenet.freeze_feature_encoder() + + @add_start_docstrings_to_model_forward(SPEECHT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqSpectrogramOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_values: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + speaker_embeddings: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + stop_labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Seq2SeqSpectrogramOutput]: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install + soundfile*). To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`): + Float values of input mel spectrogram. + + SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If + `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see + `past_key_values`). + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): + Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See + [`SpeechT5Processor.__call__`] for details. + + Returns: + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, set_seed + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset( + ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ... 
) # doctest: +IGNORE_RESULT + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc") + >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc") + >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + + >>> # audio file is decoded on the fly + >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + + >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file + + >>> set_seed(555) # make deterministic + + >>> # generate speech + >>> speech = model.generate_speech(inputs["input_values"], speaker_embeddings, vocoder=vocoder) + >>> speech.shape + torch.Size([77824]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_values is None: + decoder_input_values, decoder_attention_mask = shift_spectrograms_right( + labels, self.config.reduction_factor, decoder_attention_mask + ) + + outputs = self.speecht5( + input_values=input_values, + attention_mask=attention_mask, + decoder_input_values=decoder_input_values, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + speaker_embeddings=speaker_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + _, spectrogram, logits = self.speech_decoder_postnet(outputs[0]) + + loss = None + + if not return_dict: + output = (spectrogram,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSpectrogramOutput( + loss=loss, + spectrogram=spectrogram, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @torch.no_grad() + def generate_speech( + self, + input_values: torch.FloatTensor, + speaker_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: Optional[nn.Module] = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + ) -> torch.FloatTensor: + r""" + Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a + speech waveform using a vocoder. + + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. + + Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `List[float]` or + a `numpy.ndarray`, *e.g.* via the soundfile library (*pip install soundfile*). To prepare the array + into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into a tensor + of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. 
+ speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*, defaults to `None`): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `List[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `List[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. 
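
As a concrete illustration of `return_output_lengths=True`, a batched sketch (dummy zero-valued clips stand in for real 16 kHz audio; `microsoft/speecht5_vc` and `microsoft/speecht5_hifigan` are the public checkpoints):

```python
import numpy as np
import torch
from transformers import SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")
model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Two dummy clips of different lengths; padding=True collates them into one
# batch and should also yield the attention mask used below.
audio = [np.zeros(16000, dtype=np.float32), np.zeros(24000, dtype=np.float32)]
inputs = processor(audio=audio, sampling_rate=16000, padding=True, return_tensors="pt")

waveforms, waveform_lengths = model.generate_speech(
    inputs["input_values"],
    speaker_embeddings=torch.zeros((2, 512)),  # one x-vector per batch item
    attention_mask=inputs["attention_mask"],
    vocoder=vocoder,
    return_output_lengths=True,
)
print(waveforms.shape)   # (2, padded_num_frames)
print(waveform_lengths)  # concrete per-item lengths
```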
+ """ + if speaker_embeddings is None: + speaker_embeddings = torch.zeros((1, 512), device=input_values.device) + + return _generate_speech( + self, + input_values, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + +HIFIGAN_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SpeechT5HifiGanConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +@add_start_docstrings( + """HiFi-GAN vocoder.""", + HIFIGAN_START_DOCSTRING, +) +class SpeechT5HifiGan(PreTrainedModel): + config_class = SpeechT5HifiGanConfig + main_input_name = "spectrogram" + + def __init__(self, config: SpeechT5HifiGanConfig): + super().__init__(config) + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + config.model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i 
+ 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3) + + self.register_buffer("mean", torch.zeros(config.model_in_dim)) + self.register_buffer("scale", torch.ones(config.model_in_dim)) + + # Initialize weights and apply final processing + self.post_init() + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) + for layer in self.upsampler: + weight_norm(layer) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + for layer in self.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.conv_post) + + def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor: + r""" + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. + + Args: + spectrogram (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length, + config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`. + + Returns: + `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`. 
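
For example, a quick sketch with random inputs (assuming the public `microsoft/speecht5_hifigan` checkpoint, whose `model_in_dim` matches the 80 mel bins used above):

```python
import torch
from transformers import SpeechT5HifiGan

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Un-batched: (sequence_length, config.model_in_dim) -> 1-D waveform.
spectrogram = torch.randn(100, vocoder.config.model_in_dim)
print(vocoder(spectrogram).shape)  # (num_frames,): 100 frames, each upsampled by prod(upsample_rates)

# Batched: (batch_size, sequence_length, model_in_dim) -> (batch_size, num_frames).
print(vocoder(torch.randn(4, 100, vocoder.config.model_in_dim)).shape)
```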
+ """ + if self.config.normalize_before: + spectrogram = (spectrogram - self.mean) / self.scale + + is_batched = spectrogram.dim() == 3 + if not is_batched: + spectrogram = spectrogram.unsqueeze(0) + + hidden_states = spectrogram.transpose(2, 1) + + hidden_states = self.conv_pre(hidden_states) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + if not is_batched: + # remove batch dim and collapse tensor to 1-d audio waveform + waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1) + else: + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +__all__ = [ + "SpeechT5ForSpeechToText", + "SpeechT5ForSpeechToSpeech", + "SpeechT5ForTextToSpeech", + "SpeechT5Model", + "SpeechT5PreTrainedModel", + "SpeechT5HifiGan", +] diff --git a/docs/transformers/src/transformers/models/speecht5/number_normalizer.py b/docs/transformers/src/transformers/models/speecht5/number_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3314c24f24c1f8b9bc760c4ece69e0a2819888 --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/number_normalizer.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Number Normalizer class for SpeechT5.""" + +import re + + +class EnglishNumberNormalizer: + def __init__(self): + self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] + self.teens = [ + "", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ] + self.tens = ["", "ten", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] + self.thousands = [ + "", + "thousand", + "million", + "billion", + "trillion", + "quadrillion", + "quintillion", + "sextillion", + "septillion", + "octillion", + "nonillion", + "decillion", + ] + + # Define a dictionary to map currency symbols to their names + # Top most traded currencies according to + # https://en.wikipedia.org/wiki/Template:Most_traded_currencies + self.currency_symbols = { + "$": " dollars", + "€": " euros", + "£": " pounds", + "¢": " cents", + "¥": " japanese yen", + "﷼": " saudi riyal", + "₹": " indian rupees", + "₽": " russian rubles", + "฿": " thai baht", + "₺": " turkish liras", + "₴": " ukrainian hryvnia", + "₣": " swiss francs", + "₡": " costa rican colon", + "₱": " philippine peso", + "₪": " israeli shekels", + "₮": " mongolian tögrög", + "₩": " south korean won", + "₦": " nigerian naira", + "₫": " vietnamese Đồng", + } + + def spell_number(self, num): + if num == 0: + return "zero" + + parts = [] + for i in range(0, len(self.thousands)): + if num % 1000 != 0: + part = "" + hundreds = num % 1000 // 100 + tens_units = num % 100 + + if hundreds > 0: + part += self.ones[hundreds] + " hundred" + if tens_units > 0: + part += " and " + + if tens_units > 10 and tens_units < 20: + part += self.teens[tens_units - 10] + else: + tens_digit = self.tens[tens_units // 10] + ones_digit = self.ones[tens_units % 10] + if tens_digit: + part += tens_digit + if ones_digit: + if tens_digit: + part += " " + part += ones_digit + + parts.append(part) + + num //= 1000 + + return " ".join(reversed(parts)) + + def convert(self, number): + """ + Converts an individual number passed in string form to spelt-out form + """ + if "." 
in number:
+ integer_part, decimal_part = number.split(".")
+ else:
+ integer_part, decimal_part = number, "00"
+
+ # Extract currency symbol if present
+ currency_symbol = ""
+ for symbol, name in self.currency_symbols.items():
+ if integer_part.startswith(symbol):
+ currency_symbol = name
+ integer_part = integer_part[len(symbol) :]
+ break
+
+ if integer_part.startswith("-"):
+ if integer_part[1:].startswith(symbol):
+ currency_symbol = name
+ integer_part = "-" + integer_part[len(symbol) + 1 :]
+ break
+
+ # Extract 'minus' prefix for negative numbers
+ minus_prefix = ""
+ if integer_part.startswith("-"):
+ minus_prefix = "minus "
+ integer_part = integer_part[1:]
+ elif integer_part.startswith("minus"):
+ minus_prefix = "minus "
+ integer_part = integer_part[len("minus") :]
+
+ percent_suffix = ""
+ if "%" in integer_part or "%" in decimal_part:
+ percent_suffix = " percent"
+ integer_part = integer_part.replace("%", "")
+ decimal_part = decimal_part.replace("%", "")
+
+ integer_part = integer_part.zfill(3 * ((len(integer_part) - 1) // 3 + 1))
+
+ parts = []
+ for i in range(0, len(integer_part), 3):
+ chunk = int(integer_part[i : i + 3])
+ if chunk > 0:
+ part = self.spell_number(chunk)
+ unit = self.thousands[len(integer_part[i:]) // 3 - 1]
+ if unit:
+ part += " " + unit
+ parts.append(part)
+
+ spelled_integer = " ".join(parts)
+
+ # Format the spelt-out number based on conditions, such as:
+ # If it has decimal parts, currency symbol, minus prefix, etc
+ if decimal_part == "00":
+ return (
+ f"{minus_prefix}{spelled_integer}{percent_suffix}{currency_symbol}"
+ if minus_prefix or currency_symbol
+ else f"{spelled_integer}{percent_suffix}"
+ )
+ else:
+ spelled_decimal = " ".join([self.spell_number(int(digit)) for digit in decimal_part])
+ return (
+ f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}{currency_symbol}"
+ if minus_prefix or currency_symbol
+ else f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}"
+ )
+
+ def __call__(self, text):
+ """
+ Convert numbers / number-like quantities in a string to their spelt-out counterparts
+ """
+ # Form part of the pattern for all currency symbols
+ symbols = "".join(re.escape(symbol) for symbol in self.currency_symbols)
+ pattern = rf"(?<!\w)(-?[{symbols}]?\d+(?:\.\d+)?%?)(?!\w)"
+
+ # Find and replace commas in numbers (15,000 -> 15000, etc)
+ text = re.sub(r"(\d+,\d+)", lambda match: match.group(1).replace(",", ""), text)
+
+ # Use regex to find and replace numbers in the text
+ converted_text = re.sub(pattern, lambda match: self.convert(match.group(1)), text)
+ converted_text = re.sub(" +", " ", converted_text)
+
+ return converted_text
diff --git a/docs/transformers/src/transformers/models/speecht5/processing_speecht5.py b/docs/transformers/src/transformers/models/speecht5/processing_speecht5.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c038d97ae8c05a91c5bde75bfcf91a544d5166a
--- /dev/null
+++ b/docs/transformers/src/transformers/models/speecht5/processing_speecht5.py
@@ -0,0 +1,186 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
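
A quick sanity check of `EnglishNumberNormalizer`. Note that the `pattern` regex in `__call__` arrived clipped in this diff; the expected outputs below assume a reconstruction along the lines shown above (optional sign, optional currency symbol, digits, optional decimal part and percent sign), so treat them as a sketch rather than verified output:

```python
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer

normalizer = EnglishNumberNormalizer()

print(normalizer("The ticket costs $45"))       # "The ticket costs forty five dollars"
print(normalizer("It weighs 1,500 kilograms"))  # "It weighs one thousand five hundred kilograms"
print(normalizer("Growth was -2.5%"))           # "Growth was minus two point five percent"
```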
+"""Speech processor class for SpeechT5.""" + +from ...processing_utils import ProcessorMixin + + +class SpeechT5Processor(ProcessorMixin): + r""" + Constructs a SpeechT5 processor which wraps a feature extractor and a tokenizer into a single processor. + + [`SpeechT5Processor`] offers all the functionalities of [`SpeechT5FeatureExtractor`] and [`SpeechT5Tokenizer`]. See + the docstring of [`~SpeechT5Processor.__call__`] and [`~SpeechT5Processor.decode`] for more information. + + Args: + feature_extractor (`SpeechT5FeatureExtractor`): + An instance of [`SpeechT5FeatureExtractor`]. The feature extractor is a required input. + tokenizer (`SpeechT5Tokenizer`): + An instance of [`SpeechT5Tokenizer`]. The tokenizer is a required input. + """ + + feature_extractor_class = "SpeechT5FeatureExtractor" + tokenizer_class = "SpeechT5Tokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__(self, *args, **kwargs): + """ + Processes audio and text input, as well as audio and text targets. + + You can process audio by using the argument `audio`, or process audio targets by using the argument + `audio_target`. This forwards the arguments to SpeechT5FeatureExtractor's + [`~SpeechT5FeatureExtractor.__call__`]. + + You can process text by using the argument `text`, or process text labels by using the argument `text_target`. + This forwards the arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.__call__`]. + + Valid input combinations are: + + - `text` only + - `audio` only + - `text_target` only + - `audio_target` only + - `text` and `audio_target` + - `audio` and `audio_target` + - `text` and `text_target` + - `audio` and `text_target` + + Please refer to the docstring of the above two methods for more information. + """ + audio = kwargs.pop("audio", None) + text = kwargs.pop("text", None) + text_target = kwargs.pop("text_target", None) + audio_target = kwargs.pop("audio_target", None) + sampling_rate = kwargs.pop("sampling_rate", None) + + if audio is not None and text is not None: + raise ValueError( + "Cannot process both `audio` and `text` inputs. Did you mean `audio_target` or `text_target`?" + ) + if audio_target is not None and text_target is not None: + raise ValueError( + "Cannot process both `audio_target` and `text_target` inputs. Did you mean `audio` or `text`?" + ) + if audio is None and audio_target is None and text is None and text_target is None: + raise ValueError( + "You need to specify either an `audio`, `audio_target`, `text`, or `text_target` input to process." + ) + + if audio is not None: + inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + elif text is not None: + inputs = self.tokenizer(text, **kwargs) + else: + inputs = None + + if audio_target is not None: + targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs) + labels = targets["input_values"] + elif text_target is not None: + targets = self.tokenizer(text_target, **kwargs) + labels = targets["input_ids"] + else: + targets = None + + if inputs is None: + return targets + + if targets is not None: + inputs["labels"] = labels + + decoder_attention_mask = targets.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs + + def pad(self, *args, **kwargs): + """ + Collates the audio and text inputs, as well as their targets, into a padded batch. 
+ + Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded + by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`]. + + Valid input combinations are: + + - `input_ids` only + - `input_values` only + - `labels` only, either log-mel spectrograms or text tokens + - `input_ids` and log-mel spectrogram `labels` + - `input_values` and text `labels` + + Please refer to the docstring of the above two methods for more information. + """ + input_values = kwargs.pop("input_values", None) + input_ids = kwargs.pop("input_ids", None) + labels = kwargs.pop("labels", None) + + if input_values is not None and input_ids is not None: + raise ValueError("Cannot process both `input_values` and `input_ids` inputs.") + if input_values is None and input_ids is None and labels is None: + raise ValueError( + "You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded." + ) + + if input_values is not None: + inputs = self.feature_extractor.pad(input_values, *args, **kwargs) + elif input_ids is not None: + inputs = self.tokenizer.pad(input_ids, **kwargs) + else: + inputs = None + + if labels is not None: + if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]): + targets = self.tokenizer.pad(labels, **kwargs) + labels = targets["input_ids"] + else: + feature_size_hack = self.feature_extractor.feature_size + self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins + targets = self.feature_extractor.pad(labels, *args, **kwargs) + self.feature_extractor.feature_size = feature_size_hack + labels = targets["input_values"] + else: + targets = None + + if inputs is None: + return targets + + if targets is not None: + inputs["labels"] = labels + + decoder_attention_mask = targets.get("attention_mask") + if decoder_attention_mask is not None: + inputs["decoder_attention_mask"] = decoder_attention_mask + + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.batch_decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + +__all__ = ["SpeechT5Processor"] diff --git a/docs/transformers/src/transformers/models/speecht5/tokenization_speecht5.py b/docs/transformers/src/transformers/models/speecht5/tokenization_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..5817fb20fe0c098a632ef08a266b3ff66a473e5d --- /dev/null +++ b/docs/transformers/src/transformers/models/speecht5/tokenization_speecht5.py @@ -0,0 +1,223 @@ +# coding=utf-8 +# Copyright 2023 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for SpeechT5."""
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+from .number_normalizer import EnglishNumberNormalizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "spm_char.model"}
+
+
+@requires(backends=("sentencepiece",))
+class SpeechT5Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a SpeechT5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ normalize (`bool`, *optional*, defaults to `False`):
+ Whether to convert numeric quantities in the text to their spelt-out English counterparts.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+ using the forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Attributes:
+ sp_model (`SentencePieceProcessor`):
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
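
For example, a short sketch assuming the public `microsoft/speecht5_tts` checkpoint (the exact token output depends on its character-level vocabulary):

```python
from transformers import SpeechT5Tokenizer

tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts", normalize=True)

ids = tokenizer("I owe you $20")["input_ids"]
print(ids[-1] == tokenizer.eos_token_id)  # True: a lone sequence just gets </s> appended
print(tokenizer.decode(ids))              # "$20" comes back spelt out, e.g. "... twenty dollars</s>"
```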
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + unk_token="", + pad_token="", + normalize=False, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.vocab_file = vocab_file + self.normalize = normalize + self._normalizer = None + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + normalize=normalize, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + normalize = kwargs.pop("normalize", self.normalize) + if is_split_into_words: + text = " " + text + if normalize: + text = self.normalizer(text) + return (text, kwargs) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + @property + def normalizer(self): + if self._normalizer is None: + self._normalizer = EnglishNumberNormalizer() + return self._normalizer + + @normalizer.setter + def normalizer(self, value): + self._normalizer = value + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: + """Build model inputs from a sequence by appending eos_token_id.""" + if token_ids_1 is None: + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = 
None, already_has_special_tokens: bool = False + ) -> List[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + suffix_ones = [1] + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + suffix_ones + return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["SpeechT5Tokenizer"] diff --git a/docs/transformers/src/transformers/models/splinter/__init__.py b/docs/transformers/src/transformers/models/splinter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..258d79a9c347991b6b10d66608c5cc2b3a9cecae --- /dev/null +++ b/docs/transformers/src/transformers/models/splinter/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_splinter import * + from .modeling_splinter import * + from .tokenization_splinter import * + from .tokenization_splinter_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/splinter/configuration_splinter.py b/docs/transformers/src/transformers/models/splinter/configuration_splinter.py new file mode 100644 index 0000000000000000000000000000000000000000..533b067ed34f3118a9f85896120110103b7993f0 --- /dev/null +++ b/docs/transformers/src/transformers/models/splinter/configuration_splinter.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Splinter model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SplinterConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SplinterModel`]. It is used to instantiate an + Splinter model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Splinter + [tau/splinter-base](https://huggingface.co/tau/splinter-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Splinter model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SplinterModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`SplinterModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + question_token_id (`int`, *optional*, defaults to 104): + The id of the `[QUESTION]` token. 
+ + Example: + + ```python + >>> from transformers import SplinterModel, SplinterConfig + + >>> # Initializing a Splinter tau/splinter-base style configuration + >>> configuration = SplinterConfig() + + >>> # Initializing a model from the tau/splinter-base style configuration + >>> model = SplinterModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "splinter" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_cache=True, + pad_token_id=0, + question_token_id=104, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.question_token_id = question_token_id + + +__all__ = ["SplinterConfig"] diff --git a/docs/transformers/src/transformers/models/splinter/modeling_splinter.py b/docs/transformers/src/transformers/models/splinter/modeling_splinter.py new file mode 100644 index 0000000000000000000000000000000000000000..174e766598a0cedcab53c2250db2ab9978649fd7 --- /dev/null +++ b/docs/transformers/src/transformers/models/splinter/modeling_splinter.py @@ -0,0 +1,1115 @@ +# coding=utf-8 +# Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Splinter model.""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_splinter import SplinterConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "tau/splinter-base" +_CONFIG_FOR_DOC = "SplinterConfig" + + +class SplinterEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: Optional[int] = 0, + ) -> Tuple: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter +class SplinterSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = 
config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
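+        # NOTE (shapes): query_layer is [batch, num_heads, q_len, head_dim] and
+        # key_layer.transpose(-1, -2) is [batch, num_heads, head_dim, k_len], so the matmul
+        # below yields attention_scores of shape [batch, num_heads, q_len, k_len]; the scores
+        # are scaled by 1/sqrt(head_dim) further down before the softmax over the key axis.
+        # For the "relative_key" variants, distance = i - j lies in [-(L - 1), L - 1] and is
+        # shifted by max_position_embeddings - 1 so it can index distance_embedding.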
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SplinterModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
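+        # Concretely, this dropout zeroes individual (query, key) attention weights, so a
+        # query position may ignore some key positions entirely on a given forward pass.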
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Splinter +class SplinterSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +SPLINTER_SELF_ATTENTION_CLASSES = { + "eager": SplinterSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER +class SplinterAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = SplinterSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Splinter +class SplinterIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if 
isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Splinter +class SplinterOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter +class SplinterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = SplinterAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = SplinterAttention(config, position_embedding_type="absolute") + self.intermediate = SplinterIntermediate(config) + self.output = SplinterOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = 
cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter +class SplinterEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SplinterPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SplinterConfig + base_model_prefix = "splinter" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SPLINTER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SplinterConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SPLINTER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `{0}`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `{0}`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Splinter Model transformer outputting raw hidden-states without any specific head on top.", + SPLINTER_START_DOCSTRING, +) +class SplinterModel(SplinterPreTrainedModel): + """ + The model is an encoder (with only self-attention) following the architecture described in [Attention is all you + need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, + Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = SplinterEmbeddings(config) + self.encoder = SplinterEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
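+
+        Example (a minimal usage sketch; the checkpoint name follows `_CHECKPOINT_FOR_DOC` above):
+
+        ```python
+        >>> from transformers import AutoTokenizer, SplinterModel
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base")
+        >>> model = SplinterModel.from_pretrained("tau/splinter-base")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state  # shape: (batch_size, sequence_length, hidden_size)
+        ```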
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=sequence_output, + past_key_values=encoder_outputs.past_key_values, + 
hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+class SplinterFullyConnectedLayer(nn.Module):
+    def __init__(self, input_dim, output_dim, hidden_act="gelu"):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        self.dense = nn.Linear(self.input_dim, self.output_dim)
+        self.act_fn = ACT2FN[hidden_act]
+        self.LayerNorm = nn.LayerNorm(self.output_dim)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(inputs)
+        hidden_states = self.act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class QuestionAwareSpanSelectionHead(nn.Module):
+    """
+    Implementation of the Question-Aware Span Selection (QASS) head, described in Splinter's paper:
+    https://arxiv.org/abs/2101.00438
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.query_start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
+        self.query_end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
+        self.start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
+        self.end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
+
+        self.start_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.end_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+
+    def forward(self, inputs, positions):
+        _, _, dim = inputs.size()
+        index = positions.unsqueeze(-1).repeat(1, 1, dim)  # [batch_size, num_positions, dim]
+        gathered_reps = torch.gather(inputs, dim=1, index=index)  # [batch_size, num_positions, dim]
+
+        query_start_reps = self.query_start_transform(gathered_reps)  # [batch_size, num_positions, dim]
+        query_end_reps = self.query_end_transform(gathered_reps)  # [batch_size, num_positions, dim]
+        start_reps = self.start_transform(inputs)  # [batch_size, seq_length, dim]
+        end_reps = self.end_transform(inputs)  # [batch_size, seq_length, dim]
+
+        hidden_states = self.start_classifier(query_start_reps)  # [batch_size, num_positions, dim]
+        start_reps = start_reps.permute(0, 2, 1)  # [batch_size, dim, seq_length]
+        start_logits = torch.matmul(hidden_states, start_reps)
+
+        hidden_states = self.end_classifier(query_end_reps)
+        end_reps = end_reps.permute(0, 2, 1)
+        end_logits = torch.matmul(hidden_states, end_reps)
+
+        return start_logits, end_logits
+
+
+@add_start_docstrings(
+    """
+    Splinter Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
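+
+    The head is the `QuestionAwareSpanSelectionHead` (QASS) defined above: hidden states are gathered at the
+    `[QUESTION]` token positions, transformed into query start/end representations, and scored against the
+    transformed sequence states through the bias-free classifiers, i.e. roughly `start_logits = (q_s W_s) H^T`
+    and `end_logits = (q_e W_e) H^T` (this summary paraphrases the forward pass above).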
+ """, + SPLINTER_START_DOCSTRING, +) +class SplinterForQuestionAnswering(SplinterPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.splinter = SplinterModel(config) + self.splinter_qass = QuestionAwareSpanSelectionHead(config) + self.question_token_id = config.question_token_id + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + question_positions: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size, + num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be + the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size, + sequence_length)`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + question_positions_were_none = False + if question_positions is None: + if input_ids is not None: + question_position_for_each_example = torch.argmax( + (torch.eq(input_ids, self.question_token_id)).int(), dim=-1 + ) + else: + question_position_for_each_example = torch.zeros( + inputs_embeds.size(0), dtype=torch.long, layout=inputs_embeds.layout, device=inputs_embeds.device + ) + question_positions = question_position_for_each_example.unsqueeze(-1) + question_positions_were_none = True + + outputs = self.splinter( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + start_logits, end_logits = self.splinter_qass(sequence_output, question_positions) + + if question_positions_were_none: + start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1) + + if attention_mask is not None: + start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min + end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class SplinterForPreTrainingOutput(ModelOutput): + """ + Class for outputs of Splinter as a span selection model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings( + """ + Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task + is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans + instead. + """, + SPLINTER_START_DOCSTRING, +) +class SplinterForPreTraining(SplinterPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.splinter = SplinterModel(config) + self.splinter_qass = QuestionAwareSpanSelectionHead(config) + self.question_token_id = config.question_token_id + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length") + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + question_positions: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, SplinterForPreTrainingOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size, + num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be + the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size, + sequence_length)`. 
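+
+        Example (an illustrative sketch; recurring spans in the text are replaced by `[QUESTION]` tokens, and
+        loading a plain `tau/splinter-base` checkpoint into this head is an assumption made for brevity):
+
+        ```python
+        >>> from transformers import AutoTokenizer, SplinterForPreTraining
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base")
+        >>> model = SplinterForPreTraining.from_pretrained("tau/splinter-base")
+
+        >>> text = "[QUESTION] directed Jaws, and [QUESTION] also directed E.T."
+        >>> inputs = tokenizer(text, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)  # question_positions is inferred from the [QUESTION] token ids
+        >>> start_logits = outputs.start_logits  # shape: (batch_size, num_questions, sequence_length)
+        ```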
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if question_positions is None and start_positions is not None and end_positions is not None: + raise TypeError("question_positions must be specified in order to calculate the loss") + + elif question_positions is None and input_ids is None: + raise TypeError("question_positions must be specified when input_embeds is used") + + elif question_positions is None: + question_positions = self._prepare_question_positions(input_ids) + + outputs = self.splinter( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + batch_size, sequence_length, dim = sequence_output.size() + # [batch_size, num_questions, sequence_length] + start_logits, end_logits = self.splinter_qass(sequence_output, question_positions) + + num_questions = question_positions.size(1) + if attention_mask is not None: + attention_mask_for_each_question = attention_mask.unsqueeze(1).expand( + batch_size, num_questions, sequence_length + ) + start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min + end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min + + total_loss = None + # [batch_size, num_questions, sequence_length] + if start_positions is not None and end_positions is not None: + # sometimes the start/end positions are outside our model inputs, we ignore these terms + start_positions.clamp_(0, max(0, sequence_length - 1)) + end_positions.clamp_(0, max(0, sequence_length - 1)) + + # Ignore zero positions in the loss. Splinter never predicts zero + # during pretraining and zero is used for padding question + # tokens as well as for start and end positions of padded + # question tokens. 
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id) + start_loss = loss_fct( + start_logits.view(batch_size * num_questions, sequence_length), + start_positions.view(batch_size * num_questions), + ) + end_loss = loss_fct( + end_logits.view(batch_size * num_questions, sequence_length), + end_positions.view(batch_size * num_questions), + ) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return SplinterForPreTrainingOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor: + rows, flat_positions = torch.where(input_ids == self.config.question_token_id) + num_questions = torch.bincount(rows) + positions = torch.full( + (input_ids.size(0), num_questions.max()), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + cols = torch.cat([torch.arange(n) for n in num_questions]) + positions[rows, cols] = flat_positions + return positions + + +__all__ = [ + "SplinterForQuestionAnswering", + "SplinterForPreTraining", + "SplinterLayer", + "SplinterModel", + "SplinterPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/splinter/tokenization_splinter.py b/docs/transformers/src/transformers/models/splinter/tokenization_splinter.py new file mode 100644 index 0000000000000000000000000000000000000000..76b7543eb2eec474b947450360ba1893c55cb8b1 --- /dev/null +++ b/docs/transformers/src/transformers/models/splinter/tokenization_splinter.py @@ -0,0 +1,507 @@ +# coding=utf-8 +# Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Splinter.""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class SplinterTokenizer(PreTrainedTokenizer): + r""" + Construct a Splinter tokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. 
Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + question_token (`str`, *optional*, defaults to `"[QUESTION]"`): + The token used for constructing question representations. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + question_token="[QUESTION]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a pretrained"
+                " model use `tokenizer = SplinterTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        self.vocab = load_vocab(vocab_file)
+        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(
+                do_lower_case=do_lower_case,
+                never_split=never_split,
+                tokenize_chinese_chars=tokenize_chinese_chars,
+                strip_accents=strip_accents,
+            )
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+        self.question_token = question_token
+        super().__init__(
+            do_lower_case=do_lower_case,
+            do_basic_tokenize=do_basic_tokenize,
+            never_split=never_split,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            cls_token=cls_token,
+            mask_token=mask_token,
+            question_token=question_token,
+            tokenize_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            **kwargs,
+        )
+
+    @property
+    def question_token_id(self):
+        """
+        `Optional[int]`: Id of the question token in the vocabulary, used to condition the answer on a question
+        representation.
+        """
+        return self.convert_tokens_to_ids(self.question_token)
+
+    @property
+    def do_lower_case(self):
+        return self.basic_tokenizer.do_lower_case
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def get_vocab(self):
+        return dict(self.vocab, **self.added_tokens_encoder)
+
+    def _tokenize(self, text):
+        split_tokens = []
+        if self.do_basic_tokenize:
+            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+                # If the token is part of the never_split set
+                if token in self.basic_tokenizer.never_split:
+                    split_tokens.append(token)
+                else:
+                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        out_string = " ".join(tokens).replace(" ##", "").strip()
+        return out_string
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a pair of sequences for question answering tasks by concatenating and adding special
+        tokens. A Splinter sequence has the following format:
+
+        - single sequence: `[CLS] X [SEP]`
+        - pair of sequences for question answering: `[CLS] question_tokens [QUESTION] . [SEP] context_tokens [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                The question token IDs if pad_on_right, else the context token IDs
+            token_ids_1 (`List[int]`, *optional*):
+                The context token IDs if pad_on_right, else the question token IDs
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
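+
+        Example (an illustrative sketch; `cls_id`, `sep_id`, `q_id` and `dot_id` stand for the IDs of
+        `[CLS]`, `[SEP]`, `[QUESTION]` and `"."`, and the integers are hypothetical vocabulary indices):
+
+        ```python
+        # padding_side == "right" (question first, then context):
+        # build_inputs_with_special_tokens([5, 6], [7, 8, 9])
+        # -> [cls_id, 5, 6, q_id, dot_id, sep_id, 7, 8, 9, sep_id]
+        ```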
+ """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + cls = [self.cls_token_id] + sep = [self.sep_token_id] + question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")] + if self.padding_side == "right": + # Input is question-then-context + return cls + token_ids_0 + question_suffix + sep + token_ids_1 + sep + else: + # Input is context-then-question + return cls + token_ids_0 + sep + token_ids_1 + question_suffix + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create the token type IDs corresponding to the sequences passed. [What are token type + IDs?](../glossary#token-type-ids) + + Should be overridden in a subclass if the model has a special way of building those. + + Args: + token_ids_0 (`List[int]`): The first tokenized sequence. + token_ids_1 (`List[int]`, *optional*): The second tokenized sequence. + + Returns: + `List[int]`: The token type ids. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + + if self.padding_side == "right": + # Input is question-then-context + return len(cls + token_ids_0 + question_suffix + sep) * [0] + len(token_ids_1 + sep) * [1] + else: + # Input is context-then-question + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + question_suffix + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" 
+                    )
+                    index = token_index
+                writer.write(token + "\n")
+                index += 1
+        return (vocab_file,)
+
+
+class BasicTokenizer:
+    """
+    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+    Args:
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        never_split (`Iterable`, *optional*):
+            Collection of tokens which will never be split during tokenization. Only has an effect when
+            `do_basic_tokenize=True`
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+    """
+
+    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
+        if never_split is None:
+            never_split = []
+        self.do_lower_case = do_lower_case
+        self.never_split = set(never_split)
+        self.tokenize_chinese_chars = tokenize_chinese_chars
+        self.strip_accents = strip_accents
+
+    def tokenize(self, text, never_split=None):
+        """
+        Basic Tokenization of a piece of text. Split on "white spaces" only; for sub-word tokenization, see
+        WordPieceTokenizer.
+
+        Args:
+            **never_split**: (*optional*) list of str
+                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+        """
+        # union() returns a new set by concatenating the two sets.
+        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+        text = self._clean_text(text)
+
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia.).
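+        # Illustrative example (hypothetical input): with tokenize_chinese_chars=True, a string
+        # such as "ab你好" is rewritten to "ab 你  好 " by _tokenize_chinese_chars below, so
+        # whitespace_tokenize() then yields ["ab", "你", "好"] and each CJK character becomes
+        # its own token.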
+        if self.tokenize_chinese_chars:
+            text = self._tokenize_chinese_chars(text)
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if token not in never_split:
+                if self.do_lower_case:
+                    token = token.lower()
+                    if self.strip_accents is not False:
+                        token = self._run_strip_accents(token)
+                elif self.strip_accents:
+                    token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text, never_split=None):
+        """Splits punctuation on a piece of text."""
+        if never_split is not None and text in never_split:
+            return [text]
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)
+            or (cp >= 0x20000 and cp <= 0x2A6DF)
+            or (cp >= 0x2A700 and cp <= 0x2B73F)
+            or (cp >= 0x2B740 and cp <= 0x2B81F)
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)
+        ):
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xFFFD or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+
+class WordpieceTokenizer:
+    """Runs WordPiece tokenization."""
+
+    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+
+    def tokenize(self, text):
+        """
+        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+
+        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+        Args:
+            text: A single token or whitespace separated tokens.
This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["SplinterTokenizer"] diff --git a/docs/transformers/src/transformers/models/splinter/tokenization_splinter_fast.py b/docs/transformers/src/transformers/models/splinter/tokenization_splinter_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..85dd01a2be036fb2bd9a9201fa31f904facfeaf5 --- /dev/null +++ b/docs/transformers/src/transformers/models/splinter/tokenization_splinter_fast.py @@ -0,0 +1,193 @@ +# coding=utf-8 +# Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for Splinter.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_splinter import SplinterTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +class SplinterTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" Splinter tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. 
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + question_token (`str`, *optional*, defaults to `"[QUESTION]"`): + The token used for constructing question representations. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = SplinterTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + question_token="[QUESTION]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + additional_special_tokens=(question_token,), + **kwargs, + ) + + pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + pre_tok_state.get("lowercase", do_lower_case) != do_lower_case + or pre_tok_state.get("strip_accents", strip_accents) != strip_accents + ): + pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) + pre_tok_state["lowercase"] = do_lower_case + pre_tok_state["strip_accents"] = strip_accents + self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) + + self.do_lower_case = do_lower_case + + @property + def question_token_id(self): + """ + `Optional[int]`: Id of the question token in the vocabulary, used to condition the answer on a question + representation. + """ + return self.convert_tokens_to_ids(self.question_token) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a pair of sequence for question answering tasks by concatenating and adding special + tokens. A Splinter sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences for question answering: `[CLS] question_tokens [QUESTION] . 
[SEP] context_tokens [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                The question token IDs if pad_on_right, else the context token IDs
+            token_ids_1 (`List[int]`, *optional*):
+                The context token IDs if pad_on_right, else the question token IDs
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")]
+        if self.padding_side == "right":
+            # Input is question-then-context
+            return cls + token_ids_0 + question_suffix + sep + token_ids_1 + sep
+        else:
+            # Input is context-then-question
+            return cls + token_ids_0 + sep + token_ids_1 + question_suffix + sep
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create the token type IDs corresponding to the sequences passed. [What are token type
+        IDs?](../glossary#token-type-ids)
+
+        Should be overridden in a subclass if the model has a special way of building those.
+
+        Args:
+            token_ids_0 (`List[int]`): The first tokenized sequence.
+            token_ids_1 (`List[int]`, *optional*): The second tokenized sequence.
+
+        Returns:
+            `List[int]`: The token type ids.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")]
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+
+        if self.padding_side == "right":
+            # Input is question-then-context
+            return len(cls + token_ids_0 + question_suffix + sep) * [0] + len(token_ids_1 + sep) * [1]
+        else:
+            # Input is context-then-question
+            return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + question_suffix + sep) * [1]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+        return tuple(files)
+
+
+__all__ = ["SplinterTokenizerFast"]
diff --git a/docs/transformers/src/transformers/models/squeezebert/__init__.py b/docs/transformers/src/transformers/models/squeezebert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a760d2b5cae36272850ddd2035976e476a6315
--- /dev/null
+++ b/docs/transformers/src/transformers/models/squeezebert/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_squeezebert import * + from .modeling_squeezebert import * + from .tokenization_squeezebert import * + from .tokenization_squeezebert_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/squeezebert/configuration_squeezebert.py b/docs/transformers/src/transformers/models/squeezebert/configuration_squeezebert.py new file mode 100644 index 0000000000000000000000000000000000000000..a659517545530eadac957d23e6cb2cb4e6f0d13e --- /dev/null +++ b/docs/transformers/src/transformers/models/squeezebert/configuration_squeezebert.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SqueezeBERT model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SqueezeBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SqueezeBertModel`]. It is used to instantiate a + SqueezeBERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SqueezeBERT + [squeezebert/squeezebert-uncased](https://huggingface.co/squeezebert/squeezebert-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the SqueezeBERT model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SqueezeBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. 
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The ID of the token in the word embedding to use as padding.
+        embedding_size (`int`, *optional*, defaults to 768):
+            The dimension of the word embedding vectors.
+        q_groups (`int`, *optional*, defaults to 4):
+            The number of groups in Q layer.
+        k_groups (`int`, *optional*, defaults to 4):
+            The number of groups in K layer.
+        v_groups (`int`, *optional*, defaults to 4):
+            The number of groups in V layer.
+        post_attention_groups (`int`, *optional*, defaults to 1):
+            The number of groups in the first feed forward network layer.
+        intermediate_groups (`int`, *optional*, defaults to 4):
+            The number of groups in the second feed forward network layer.
+        output_groups (`int`, *optional*, defaults to 4):
+            The number of groups in the third feed forward network layer.
+
+    Examples:
+
+    ```python
+    >>> from transformers import SqueezeBertConfig, SqueezeBertModel
+
+    >>> # Initializing a SqueezeBERT configuration
+    >>> configuration = SqueezeBertConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration above
+    >>> model = SqueezeBertModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "squeezebert"
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        pad_token_id=0,
+        embedding_size=768,
+        q_groups=4,
+        k_groups=4,
+        v_groups=4,
+        post_attention_groups=1,
+        intermediate_groups=4,
+        output_groups=4,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.embedding_size = embedding_size
+        self.q_groups = q_groups
+        self.k_groups = k_groups
+        self.v_groups = v_groups
+        self.post_attention_groups = post_attention_groups
+        self.intermediate_groups = intermediate_groups
+        self.output_groups = output_groups
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->SqueezeBert
+class SqueezeBertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
+
+
+__all__ = ["SqueezeBertConfig", "SqueezeBertOnnxConfig"]
diff --git a/docs/transformers/src/transformers/models/squeezebert/modeling_squeezebert.py b/docs/transformers/src/transformers/models/squeezebert/modeling_squeezebert.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb29035ba94233b853041d12371087603a728a8
--- /dev/null
+++ b/docs/transformers/src/transformers/models/squeezebert/modeling_squeezebert.py
@@ -0,0 +1,1101 @@
+# coding=utf-8
+# Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SqueezeBert model."""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_squeezebert import SqueezeBertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "squeezebert/squeezebert-uncased"
+_CONFIG_FOR_DOC = "SqueezeBertConfig"
+
+
+class SqueezeBertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class MatMulWrapper(nn.Module):
+    """
+    Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
+    torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, mat1, mat2):
+        """
+        :param mat1, mat2: two torch tensors
+        :return: matmul of these tensors
+
+        Here are the typical dimensions found in BERT (the B is optional):
+        mat1.shape: [B, <optional extra dims>, M, K]
+        mat2.shape: [B, <optional extra dims>, K, N]
+        output shape: [B, <optional extra dims>, M, N]
+        """
+        return torch.matmul(mat1, mat2)
+
+
+class SqueezeBertLayerNorm(nn.LayerNorm):
+    """
+    This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.
+
+    N = batch, C = channels, W = sequence length
+    """
+
+    def __init__(self, hidden_size, eps=1e-12):
+        nn.LayerNorm.__init__(self, normalized_shape=hidden_size, eps=eps)  # instantiates self.{weight, bias, eps}
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)
+        x = nn.LayerNorm.forward(self, x)
+        return x.permute(0, 2, 1)
+
+
+class ConvDropoutLayerNorm(nn.Module):
+    """
+    ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
+    """
+
+    def __init__(self, cin, cout, groups, dropout_prob):
+        super().__init__()
+
+        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
+        self.layernorm = SqueezeBertLayerNorm(cout)
+        self.dropout = nn.Dropout(dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        x = self.conv1d(hidden_states)
+        x = self.dropout(x)
+        x = x + input_tensor
+        x = self.layernorm(x)
+        return x
+
+
+class ConvActivation(nn.Module):
+    """
+    ConvActivation: Conv, Activation
+    """
+
+    def __init__(self, cin, cout, groups, act):
+        super().__init__()
+        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
+        self.act = ACT2FN[act]
+
+    def forward(self, x):
+        output = self.conv1d(x)
+        return self.act(output)
+
+
+class SqueezeBertSelfAttention(nn.Module):
+    def __init__(self, config, cin, q_groups=1, k_groups=1, v_groups=1):
+        """
+        config = used for some things; ignored for others (work in progress...)
+        cin = input channels = output channels
+        groups = number of groups to use in conv1d layers
+        """
+        super().__init__()
+        if cin % config.num_attention_heads != 0:
+            raise ValueError(
+                f"cin ({cin}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
+            )
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(cin / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=q_groups)
+        self.key = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=k_groups)
+        self.value = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=v_groups)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.softmax = nn.Softmax(dim=-1)
+
+        self.matmul_qk = MatMulWrapper()
+        self.matmul_qkv = MatMulWrapper()
+
+    def transpose_for_scores(self, x):
+        """
+        - input: [N, C, W]
+        - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
+        """
+        new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1])  # [N, C1, C2, W]
+        x = x.view(*new_x_shape)
+        return x.permute(0, 1, 3, 2)  # [N, C1, C2, W] --> [N, C1, W, C2]
+
+    def transpose_key_for_scores(self, x):
+        """
+        - input: [N, C, W]
+        - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
+        """
+        new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1])  # [N, C1, C2, W]
+        x = x.view(*new_x_shape)
+        # no `permute` needed
+        return x
+
+    def transpose_output(self, x):
+        """
+        - input: [N, C1, W, C2]
+        - output: [N, C, W]
+        """
+        x = x.permute(0, 1, 3, 2).contiguous()  # [N, C1, C2, W]
+        new_x_shape = (x.size()[0], self.all_head_size, x.size()[3])  # [N, C, W]
+        x = x.view(*new_x_shape)
+        return x
+
+    def forward(self, hidden_states, attention_mask, output_attentions):
+        """
+        expects hidden_states in [N, C, W] data layout.
+
+        The attention_mask data layout is [N, W], and it does not need to be transposed.
+        """
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_key_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_score = self.matmul_qk(query_layer, key_layer)
+        attention_score = attention_score / math.sqrt(self.attention_head_size)
+        # Apply the attention mask (precomputed for all layers in the SqueezeBertModel forward() function)
+        attention_score = attention_score + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = self.softmax(attention_score)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = self.matmul_qkv(attention_probs, value_layer)
+        context_layer = self.transpose_output(context_layer)
+
+        result = {"context_layer": context_layer}
+        if output_attentions:
+            result["attention_score"] = attention_score
+        return result
+
+
+class SqueezeBertModule(nn.Module):
+    def __init__(self, config):
+        """
+        - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
+          the module
+        - intermediate_size = output chans for intermediate layer
+        - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
+          allow different groups for different layers)
+        """
+        super().__init__()
+
+        c0 = config.hidden_size
+        c1 = config.hidden_size
+        c2 = config.intermediate_size
+        c3 = config.hidden_size
+
+        self.attention = SqueezeBertSelfAttention(
+            config=config, cin=c0, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups
+        )
+        self.post_attention = ConvDropoutLayerNorm(
+            cin=c0, cout=c1, groups=config.post_attention_groups, dropout_prob=config.hidden_dropout_prob
+        )
+        self.intermediate = ConvActivation(cin=c1, cout=c2, groups=config.intermediate_groups, act=config.hidden_act)
+        self.output = ConvDropoutLayerNorm(
+            cin=c2, cout=c3, groups=config.output_groups, dropout_prob=config.hidden_dropout_prob
+        )
+
+    def forward(self, hidden_states, attention_mask, output_attentions):
+        att = self.attention(hidden_states, attention_mask, output_attentions)
+        attention_output = att["context_layer"]
+
+        post_attention_output = self.post_attention(attention_output, hidden_states)
+        intermediate_output = self.intermediate(post_attention_output)
+        layer_output = self.output(intermediate_output, post_attention_output)
+
+        output_dict = {"feature_map": layer_output}
+        if output_attentions:
+            output_dict["attention_score"] = att["attention_score"]
+
+        return output_dict
+
+
+class SqueezeBertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        assert config.embedding_size == config.hidden_size, (
+            "If you want embedding_size != hidden_size, "
+            "please insert a Conv1d layer to adjust the number of channels "
+            "before the first SqueezeBertModule."
+        )
+
+        self.layers = nn.ModuleList(SqueezeBertModule(config) for _ in range(config.num_hidden_layers))
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        if head_mask is None:
+            head_mask_is_all_none = True
+        elif head_mask.count(None) == len(head_mask):
+            head_mask_is_all_none = True
+        else:
+            head_mask_is_all_none = False
+        assert head_mask_is_all_none is True, "head_mask is not yet supported in the SqueezeBert implementation."
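+
+        # Shape walkthrough (illustrative numbers, not from the source): with batch_size=8,
+        # sequence_length=128 and hidden_size=768, the permute below turns the [8, 128, 768]
+        # embedding output into [8, 768, 128] so the grouped Conv1d layers can treat
+        # hidden_size as channels (NCW layout); the final permute restores [8, 128, 768].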
+ + # [batch_size, sequence_length, hidden_size] --> [batch_size, hidden_size, sequence_length] + hidden_states = hidden_states.permute(0, 2, 1) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for layer in self.layers: + if output_hidden_states: + hidden_states = hidden_states.permute(0, 2, 1) + all_hidden_states += (hidden_states,) + hidden_states = hidden_states.permute(0, 2, 1) + + layer_output = layer.forward(hidden_states, attention_mask, output_attentions) + + hidden_states = layer_output["feature_map"] + + if output_attentions: + all_attentions += (layer_output["attention_score"],) + + # [batch_size, hidden_size, sequence_length] --> [batch_size, sequence_length, hidden_size] + hidden_states = hidden_states.permute(0, 2, 1) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class SqueezeBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class SqueezeBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class SqueezeBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = SqueezeBertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self) -> None: + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class SqueezeBertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = SqueezeBertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class SqueezeBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SqueezeBertConfig + base_model_prefix = "transformer" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, SqueezeBertLMPredictionHead): + module.bias.data.zero_() + + +SQUEEZEBERT_START_DOCSTRING = r""" + + The SqueezeBERT model was proposed in [SqueezeBERT: What can computer vision teach NLP about efficient neural + networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. + Keutzer + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + For best results finetuning SqueezeBERT on text classification tasks, it is recommended to use the + *squeezebert/squeezebert-mnli-headless* checkpoint as a starting point. + + Parameters: + config ([`SqueezeBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + + Hierarchy: + + ``` + Internal class hierarchy: + SqueezeBertModel + SqueezeBertEncoder + SqueezeBertModule + SqueezeBertSelfAttention + ConvActivation + ConvDropoutLayerNorm + ``` + + Data layouts: + + ``` + Input data is in [batch, sequence_length, hidden_size] format. + + Data inside the encoder is in [batch, hidden_size, sequence_length] format. But, if `output_hidden_states == True`, the data from inside the encoder is returned in [batch, sequence_length, hidden_size] format. + + The final output of the encoder is in [batch, sequence_length, hidden_size] format. + ``` +""" + +SQUEEZEBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare SqueezeBERT Model transformer outputting raw hidden-states without any specific head on top.", + SQUEEZEBERT_START_DOCSTRING, +) +class SqueezeBertModel(SqueezeBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = SqueezeBertEmbeddings(config) + self.encoder = SqueezeBertEncoder(config) + self.pooler = SqueezeBertPooler(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top.""", SQUEEZEBERT_START_DOCSTRING) +class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, 
config): + super().__init__(config) + + self.transformer = SqueezeBertModel(config) + self.cls = SqueezeBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
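+
+    A minimal usage sketch (illustrative, using the *squeezebert/squeezebert-mnli-headless* checkpoint
+    recommended above; the label mapping depends on the checkpoint):
+
+    ```python
+    >>> from transformers import AutoTokenizer, SqueezeBertForSequenceClassification
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("squeezebert/squeezebert-mnli-headless")
+    >>> model = SqueezeBertForSequenceClassification.from_pretrained("squeezebert/squeezebert-mnli-headless")
+    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+    >>> logits = model(**inputs).logits  # shape [1, num_labels]
+    ```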
+ """, + SQUEEZEBERT_START_DOCSTRING, +) +class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.transformer = SqueezeBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + SqueezeBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. 
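+
+    A minimal usage sketch (the checkpoint name is illustrative). Inputs carry an extra
+    `num_choices` dimension, which `forward` flattens before running the transformer:
+
+    ```python
+    >>> from transformers import AutoTokenizer, SqueezeBertForMultipleChoice
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("squeezebert/squeezebert-uncased")
+    >>> model = SqueezeBertForMultipleChoice.from_pretrained("squeezebert/squeezebert-uncased")
+    >>> prompt = "The cat sat on"
+    >>> choices = ["the mat.", "the moon."]
+    >>> encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
+    >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})  # [1, num_choices, seq_len]
+    >>> outputs.logits.shape  # one score per choice
+    torch.Size([1, 2])
+    ```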
+ """, + SQUEEZEBERT_START_DOCSTRING, +) +class SqueezeBertForMultipleChoice(SqueezeBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.transformer = SqueezeBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see + *input_ids* above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + SqueezeBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. 
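+
+    A minimal usage sketch (the checkpoint name is illustrative; logits come back per token, with
+    shape `[batch_size, sequence_length, num_labels]`):
+
+    ```python
+    >>> from transformers import AutoTokenizer, SqueezeBertForTokenClassification
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("squeezebert/squeezebert-uncased")
+    >>> model = SqueezeBertForTokenClassification.from_pretrained("squeezebert/squeezebert-uncased", num_labels=9)
+    >>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
+    >>> predicted_label_ids = model(**inputs).logits.argmax(-1)  # one label id per token
+    ```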
+ """, + SQUEEZEBERT_START_DOCSTRING, +) +class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = SqueezeBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + SqueezeBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
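+
+    A minimal usage sketch (the checkpoint name is illustrative; a checkpoint actually fine-tuned
+    on SQuAD or a similar dataset is needed for meaningful spans):
+
+    ```python
+    >>> from transformers import AutoTokenizer, SqueezeBertForQuestionAnswering
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("squeezebert/squeezebert-uncased")
+    >>> model = SqueezeBertForQuestionAnswering.from_pretrained("squeezebert/squeezebert-uncased")
+    >>> question, context = "Who proposed SqueezeBERT?", "SqueezeBERT was proposed by Iandola et al. in 2020."
+    >>> inputs = tokenizer(question, context, return_tensors="pt")
+    >>> outputs = model(**inputs)  # outputs.start_logits, outputs.end_logits
+    ```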
+ """, + SQUEEZEBERT_START_DOCSTRING, +) +class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = SqueezeBertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. 
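+
+        To turn the logits into an answer span (a sketch; assumes `inputs` and `outputs` from a
+        tokenized question/context pair and a forward pass, as in the class-level example above):
+
+        ```python
+        >>> answer_start = outputs.start_logits.argmax()
+        >>> answer_end = outputs.end_logits.argmax()
+        >>> tokenizer.decode(inputs["input_ids"][0, answer_start : answer_end + 1])
+        ```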
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SqueezeBertForMaskedLM", + "SqueezeBertForMultipleChoice", + "SqueezeBertForQuestionAnswering", + "SqueezeBertForSequenceClassification", + "SqueezeBertForTokenClassification", + "SqueezeBertModel", + "SqueezeBertModule", + "SqueezeBertPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/squeezebert/tokenization_squeezebert.py b/docs/transformers/src/transformers/models/squeezebert/tokenization_squeezebert.py new file mode 100644 index 0000000000000000000000000000000000000000..71194f3d8df7cb2e2011bb82eadd657c7fd049d1 --- /dev/null +++ b/docs/transformers/src/transformers/models/squeezebert/tokenization_squeezebert.py @@ -0,0 +1,511 @@ +# coding=utf-8 +# Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
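+# A quick orientation sketch (illustrative only; "vocab.txt" stands in for a real WordPiece
+# vocabulary file that contains the pieces shown):
+#
+#     tokenizer = SqueezeBertTokenizer("vocab.txt")
+#     tokenizer.tokenize("unaffable")  # -> ["un", "##aff", "##able"]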
+"""Tokenization classes for SqueezeBERT.""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +# Copied from transformers.models.bert.tokenization_bert.load_vocab +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BertTokenizer with Bert->SqueezeBert,BERT->SqueezeBERT +class SqueezeBertTokenizer(PreTrainedTokenizer): + r""" + Construct a SqueezeBERT tokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original SqueezeBERT). 
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + clean_up_tokenization_spaces=True, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = SqueezeBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A SqueezeBERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A SqueezeBERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). 
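+
+    A minimal sketch of this stage in isolation (with the default flags, accents are stripped
+    because `do_lower_case=True` and `strip_accents` is unset):
+
+    ```python
+    >>> BasicTokenizer().tokenize("Héllo, WORLD!")
+    ['hello', ',', 'world', '!']
+    ```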
+ + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
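+        # For example, "深度learning" becomes " 深  度 learning" here, so each CJK character
+        # ends up as its own token after the whitespace split below.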
+        if self.tokenize_chinese_chars:
+            text = self._tokenize_chinese_chars(text)
+        # prevents treating the same character with different Unicode codepoints as different characters
+        unicode_normalized_text = unicodedata.normalize("NFC", text)
+        orig_tokens = whitespace_tokenize(unicode_normalized_text)
+        split_tokens = []
+        for token in orig_tokens:
+            if token not in never_split:
+                if self.do_lower_case:
+                    token = token.lower()
+                    if self.strip_accents is not False:
+                        token = self._run_strip_accents(token)
+                elif self.strip_accents:
+                    token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text, never_split=None):
+        """Splits punctuation on a piece of text."""
+        if not self.do_split_on_punc or (never_split is not None and text in never_split):
+            return [text]
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and are handled
+        # like all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)  #
+            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
+            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
+            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+        ):  #
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xFFFD or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+
+class WordpieceTokenizer:
+    """Runs WordPiece tokenization."""
+
+    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+
+    def tokenize(self, text):
+        """
+        Tokenizes a piece of text into its word pieces. 
This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+
+        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+        Args:
+            text: A single token or whitespace-separated tokens. This should have
+                already been passed through *BasicTokenizer*.
+
+        Returns:
+            A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+__all__ = ["SqueezeBertTokenizer"]
diff --git a/docs/transformers/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py b/docs/transformers/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..a908dcbf146bdd74cdb3e295f8832594043bf219
--- /dev/null
+++ b/docs/transformers/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
@@ -0,0 +1,176 @@
+# coding=utf-8
+# Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for SqueezeBERT."""
+
+import json
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_squeezebert import SqueezeBertTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+# Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast with Bert->SqueezeBert,BERT->SqueezeBERT
+class SqueezeBertTokenizerFast(PreTrainedTokenizerFast):
+    r"""
+    Construct a "fast" SqueezeBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead. 
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original SqueezeBERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = SqueezeBertTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A SqueezeBERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A SqueezeBERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["SqueezeBertTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/stablelm/__init__.py b/docs/transformers/src/transformers/models/stablelm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..004a82f49cd66857a6d10e614da2451046a1dcad --- /dev/null +++ b/docs/transformers/src/transformers/models/stablelm/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_stablelm import * + from .modeling_stablelm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/stablelm/configuration_stablelm.py b/docs/transformers/src/transformers/models/stablelm/configuration_stablelm.py new file mode 100644 index 0000000000000000000000000000000000000000..06501792c42a31401363b9b013958b7dde154284 --- /dev/null +++ b/docs/transformers/src/transformers/models/stablelm/configuration_stablelm.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright 2024 Stability AI and The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""StableLM model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class StableLmConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`~StableLmModel`].
+    It is used to instantiate a StableLM model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+    the StableLM [stabilityai/stablelm-3b-4e1t](https://huggingface.co/stabilityai/stablelm-3b-4e1t) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used
+    to control the model outputs. Read the documentation from [`PretrainedConfig`]
+    for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50304):
+            Vocabulary size of the StableLM model. Defines the number of different tokens that
+            can be represented by the `input_ids` passed when calling [`StableLmModel`].
+        intermediate_size (`int`, *optional*, defaults to 6912):
+            Dimension of the MLP representations.
+        hidden_size (`int`, *optional*, defaults to 2560):
+            Dimension of the hidden representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 32):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string).
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing
+            all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions
+            (not used by all models). 
Only relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to `10000.0`):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
+            type and expect the model to work on a longer `max_position_embeddings`, we recommend updating this
+            value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        use_qkv_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should use bias for qkv layers.
+        qk_layernorm (`bool`, *optional*, defaults to `False`):
+            Whether or not to normalize, per head, the Queries and Keys after projecting the hidden states.
+        use_parallel_residual (`bool`, *optional*, defaults to `False`):
+            Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
+            speedup at large scales.
+        hidden_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after applying the MLP to the hidden states.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
+            Fraction of the query and key head dimensions that receive rotary embeddings. 
+ bos_token_id (int, *optional*, defaults to 0): + The id of the `BOS` token in the vocabulary. + eos_token_id (int, *optional*, defaults to 0): + The id of the `EOS` token in the vocabulary. + + Example: + + ```python + >>> from transformers import StableLmModel, StableLmConfig + + >>> # Initializing a StableLM stablelm-3b style configuration + >>> configuration = StableLmConfig() + ```""" + + model_type = "stablelm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + intermediate_size=6912, + hidden_size=2560, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + layer_norm_eps=1.0e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10_000, + rope_scaling=None, + use_qkv_bias=False, + qk_layernorm=False, + use_parallel_residual=False, + hidden_dropout=0.0, + attention_dropout=0.0, + partial_rotary_factor=0.25, + bos_token_id=0, + eos_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.use_qkv_bias = use_qkv_bias + self.qk_layernorm = qk_layernorm + self.use_parallel_residual = use_parallel_residual + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.partial_rotary_factor = partial_rotary_factor + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["StableLmConfig"] diff --git a/docs/transformers/src/transformers/models/stablelm/modeling_stablelm.py b/docs/transformers/src/transformers/models/stablelm/modeling_stablelm.py new file mode 100644 index 0000000000000000000000000000000000000000..85e20b44932b898867f3eaffdbd4124e9093aa0f --- /dev/null +++ b/docs/transformers/src/transformers/models/stablelm/modeling_stablelm.py @@ -0,0 +1,1355 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch StableLM model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_stablelm import StableLmConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "stabilityai/stablelm-3b-4e1t" +_CONFIG_FOR_DOC = "StableLmConfig" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm +class StableLmRotaryEmbedding(nn.Module): + def __init__(self, config: StableLmConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
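+
+        Example (shapes only; the values are random):
+
+        ```python
+        >>> q = k = torch.randn(1, 32, 10, 64)  # [batch, heads, seq, head_dim]
+        >>> cos = sin = torch.randn(1, 10, 64)  # [batch, seq, head_dim]
+        >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 adds the heads axis to cos/sin
+        >>> tuple(q_embed.shape)
+        (1, 32, 10, 64)
+        ```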
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm +class StableLmMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class StableLmLayerNormPerHead(nn.Module): + def __init__(self, dim, num_heads, eps=1e-5, bias=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)]) + + def forward(self, hidden_states: torch.Tensor): + # Split along the num_heads axis to get per-head inputs + # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads + states_per_heads = torch.split(hidden_states, 1, dim=1) + # Normalize and merge the heads back together + return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class StableLmAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.rope_theta = config.rope_theta + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps) + self.k_layernorm = StableLmLayerNormPerHead( + self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps + ) + + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.rotary_emb = StableLmRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights += causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) + attn_weights = self.attention_dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, 
value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class StableLmSdpaAttention(StableLmAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, 
self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs and a custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d, which does not create a causal mask in case q_len == 1.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout.p if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class StableLmFlashAttention2(StableLmAttention):
+    """
+    StableLM flash attention module. This module inherits from `StableLmAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, which needs to correctly call the public API of flash
+    attention and deal with padding tokens in case the input contains any.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment,
+        # which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
+        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
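+        # Illustrative example of the difference: for q_len=2 and kv_len=4 (decoding two new tokens on top of a
+        # two-token cache), the bottom-right aligned causal mask treats the queries as the *last* two positions,
+        #     [[1, 1, 1, 0],
+        #      [1, 1, 1, 1]]
+        # while the top-left alignment of flash_attn<2.1 treats them as the *first* two positions,
+        #     [[1, 0, 0, 0],
+        #      [1, 1, 0, 0]]
+        # which is wrong whenever q_seqlen != k_seqlen (except for q_seqlen == 1).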
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # StableLmFlashAttention2 does not support output_attentions
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim.
+        # The projections are first reshaped to batch_size x num_heads x seq_length x head_dim for the per-head
+        # layernorm and the rotary embedding, and transposed back to the flash attention layout further below.
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
+        cos, sin = position_embeddings
+
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_ndims],
+            query_states[..., self.rotary_ndims :],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_ndims],
+            key_states[..., self.rotary_ndims :],
+        )
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+        # [batch_size, num_heads, seq_length, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "partial_rotation_size": self.rotary_ndims,
+                "cache_position": cache_position,
+            }
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
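+        # Illustrative shape trace: the cache stores keys/values as [batch_size, num_kv_heads, kv_len, head_dim],
+        # so q is [batch_size, num_heads, q_len, head_dim] and k/v are [batch_size, num_kv_heads, kv_len, head_dim]
+        # at this point; they are transposed back to [batch_size, seq_len, heads, head_dim] before calling
+        # `_flash_attention_forward`.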
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout.p if self.training else 0.0
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+ATTENTION_CLASSES = {
+    "eager": StableLmAttention,
+    "sdpa": StableLmSdpaAttention,
+    "flash_attention_2": StableLmFlashAttention2,
+}
+
+
+class StableLmDecoderLayer(nn.Module):
+    def __init__(self, config: StableLmConfig, layer_idx: int):
+        super().__init__()
+        self.use_parallel_residual = config.use_parallel_residual
+        self.hidden_size = config.hidden_size
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+        self.mlp = StableLmMLP(config)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_attention_layernorm = None
+        if not self.use_parallel_residual:
+            self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence token in the position embeddings. Selected in the range
+                `[0, config.n_positions - 1]`.
+
+                [What are position IDs?](../glossary#position-ids)
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
+                Cached past key and value projection states.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        self_attn_output, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
+
+        # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
+        if self.use_parallel_residual:
+            # x = x + attn(ln1(x)) + mlp(ln1(x))
+            # Fully Connected
+            mlp_output = self.mlp(hidden_states)
+            mlp_output = self.dropout(mlp_output)
+            hidden_states = residual + self_attn_output + mlp_output
+        else:
+            # x = x + attn(ln1(x))
+            # x = x + mlp(ln2(x))
+            residual = residual + self_attn_output
+            # Fully Connected
+            mlp_output = self.mlp(self.post_attention_layernorm(residual))
+            mlp_output = self.dropout(mlp_output)
+            hidden_states = residual + mlp_output
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+STABLELM_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`StableLmConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
+    STABLELM_START_DOCSTRING,
+)
+class StableLmPreTrainedModel(PreTrainedModel):
+    config_class = StableLmConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["StableLmDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_cache_class = True
+    _supports_sdpa = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
+
+
+STABLELM_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance, see our
+              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
+              cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare StableLm Model outputting raw hidden-states without any specific head on top.", + STABLELM_START_DOCSTRING, +) +class StableLmModel(StableLmPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableLmDecoderLayer`] + + Args: + config: StableLmConfig + """ + + def __init__(self, config: StableLmConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.rotary_emb = StableLmRotaryEmbedding(config=config) + + self._attn_implementation = config._attn_implementation + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
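+        # Illustrative example: for a prefill with no padding (`attention_mask` is all ones) and a dynamic cache,
+        # `_ignore_causal_mask_sdpa` below returns True, so this method returns None and the SDPA attention path
+        # later sets `is_causal=True`, letting SDPA build the causal mask internally instead of materializing a
+        # 4D float mask here.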
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        sequence_length = input_tensor.shape[1]
+        if using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal (4D) mask here.
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention's memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or returns the input `attention_mask` unchanged if it is already 4D.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with a static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
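+            # ("Inverted" means 0.0 where attention is allowed and the dtype minimum where it is masked.)
+            # Worked example of the construction in the `else` branch below (illustrative): with sequence_length=2,
+            # target_length=4 and cache_position=[2, 3] (two new tokens on top of a two-token cache), the result is
+            #     [[0, 0, 0, min],
+            #      [0, 0, 0,   0]]
+            # i.e. the query at position 2 attends keys 0..2 and the query at position 3 attends all four keys.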
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm +class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm + def __init__(self, config): + super().__init__(config) + self.model = StableLmModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only
+            for that token saves memory, which becomes significant for long sequences or a large vocabulary size.
+            If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, StableLmForCausalLM
+
+        >>> model = StableLmForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t")
+        >>> tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")
+
+        >>> prompt = "The weather is always wonderful in"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        'The weather is always wonderful in the summer in the city of San Diego. The city is located on the coast of the Pacific Ocean and is surrounded by'
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # No upscaling to float was ever done for StableLm
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The StableLm transformer with a sequence classification head on top (linear layer).
+
+    [`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
+    models (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+    each row of the batch).
+ """, + STABLELM_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->STABLELM,Llama->StableLm +class StableLmForSequenceClassification(StableLmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = StableLmModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
+            )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+    output) e.g. for Named-Entity-Recognition (NER) tasks.
+    """,
+    STABLELM_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM
+class StableLmForTokenClassification(StableLmPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = StableLmModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TokenClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
+ """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "StableLmForCausalLM", + "StableLmModel", + "StableLmPreTrainedModel", + "StableLmForSequenceClassification", + "StableLmForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/starcoder2/__init__.py b/docs/transformers/src/transformers/models/starcoder2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6349255ed3a4756aa7d449ff6af35847886dd136 --- /dev/null +++ b/docs/transformers/src/transformers/models/starcoder2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_starcoder2 import * + from .modeling_starcoder2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/starcoder2/configuration_starcoder2.py b/docs/transformers/src/transformers/models/starcoder2/configuration_starcoder2.py new file mode 100644 index 0000000000000000000000000000000000000000..b617a1cad8429c65b8de7bbdfb317bff418c9891 --- /dev/null +++ b/docs/transformers/src/transformers/models/starcoder2/configuration_starcoder2.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Starcoder2 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Starcoder2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Starcoder2Model`]. It is used to instantiate a + Starcoder2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b](https://huggingface.co/bigcode/starcoder2-7b) model. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49152): + Vocabulary size of the Starcoder2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Starcoder2Model`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 30): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 24): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. Starcoder2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_epsilon (`float`, *optional*, defaults to 1e-05): + Epsilon value for the layer norm + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + bos_token_id (`int`, *optional*, defaults to 50256): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 50256): + The id of the "end-of-sequence" token. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. 
+            NOTE: if you apply a new rope type and you expect the model to work on longer
+            `max_position_embeddings`, we recommend you update this value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If not specified, will default to `None` (no sliding window).
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        residual_dropout (`float`, *optional*, defaults to 0.0):
+            Residual connection dropout value.
+        embedding_dropout (`float`, *optional*, defaults to 0.0):
+            Embedding dropout.
+        use_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use a bias term on the linear layers of the model.
+
+
+    ```python
+    >>> from transformers import Starcoder2Model, Starcoder2Config
+
+    >>> # Initializing a Starcoder2 7B style configuration
+    >>> configuration = Starcoder2Config()
+
+    >>> # Initializing a model from the Starcoder2 7B style configuration
+    >>> model = Starcoder2Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "starcoder2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `Starcoder2`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.c_fc": "colwise",
+        "layers.*.mlp.c_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=49152,
+        hidden_size=3072,
+        intermediate_size=12288,
+        num_hidden_layers=30,
+        num_attention_heads=24,
+        num_key_value_heads=2,
+        hidden_act="gelu_pytorch_tanh",
+        max_position_embeddings=4096,
+        initializer_range=0.018042,
+        norm_epsilon=1e-5,
+        use_cache=True,
+        bos_token_id=50256,
+        eos_token_id=50256,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        sliding_window=None,
+        attention_dropout=0.0,
+        residual_dropout=0.0,
+        embedding_dropout=0.0,
+        use_bias=True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.use_bias = use_bias
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.norm_epsilon = norm_epsilon
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        self.residual_dropout = residual_dropout
+        self.embedding_dropout = embedding_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
+
+__all__ = ["Starcoder2Config"]
diff --git a/docs/transformers/src/transformers/models/starcoder2/modeling_starcoder2.py b/docs/transformers/src/transformers/models/starcoder2/modeling_starcoder2.py
new file mode 100644
index 0000000000000000000000000000000000000000..84e420607b99ae377fc2d4e92aee257ab608f614
--- /dev/null
+++ b/docs/transformers/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -0,0 +1,1036 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/starcoder2/modular_starcoder2.py.
+# Do NOT edit this file manually, as any edits will be overwritten by the generation of
+# the file from the modular. If any change needs to be made, please apply it to the
+# modular_starcoder2.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_starcoder2 import Starcoder2Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" + +_CONFIG_FOR_DOC = "Starcoder2Config" + + +class Starcoder2MLP(nn.Module): + def __init__(self, config: Starcoder2Config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias) + self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias) + self.act = ACT2FN[config.hidden_act] + self.residual_dropout = config.residual_dropout + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training) + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. 
+ cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Starcoder2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, 
bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.residual_dropout = config.residual_dropout + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = nn.functional.dropout( + attn_output, p=self.residual_dropout, training=self.training + ) # diff with Llama + + return attn_output, attn_weights + + +class Starcoder2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Starcoder2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx) + self.mlp = Starcoder2MLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = 
self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Starcoder2RotaryEmbedding(nn.Module): + def __init__(self, config: Starcoder2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +STARCODER2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Starcoder2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.", + STARCODER2_START_DOCSTRING, +) +class Starcoder2PreTrainedModel(PreTrainedModel): + config_class = Starcoder2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Starcoder2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +STARCODER2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.", + STARCODER2_START_DOCSTRING, +) +class Starcoder2Model(Starcoder2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Starcoder2DecoderLayer`] + + Args: + config: Starcoder2Config + """ + + def __init__(self, config: Starcoder2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.rotary_emb = Starcoder2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.embedding_dropout = config.embedding_dropout + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout( + hidden_states, p=self.embedding_dropout, training=self.training + ) # main diff with Llama + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
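+        # The remainder of this method materializes a dense 4D float mask: 0.0 where attention is allowed and the
+        # dtype minimum where it is masked. Its key length follows the cache (static / sliding-window) when one is
+        # used, and the attended key length otherwise.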
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Starcoder2Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+            config (`Starcoder2Config`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class currently being used to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            if config.get_text_config().sliding_window is not None:
+                # If we have a sliding window, we should not attend to tokens beyond the sliding window length, so we
+                # also mask them out. The check verifies whether the current checkpoint was trained with a sliding
+                # window or not.
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
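The sliding-window branch in `_prepare_4d_causal_attention_mask_with_cache_position` is easiest to see on a toy case. Below is a minimal standalone sketch (sizes are illustrative; it re-implements only the boolean masking arithmetic, not the full method):

```python
import torch

# Toy reproduction of the diagonal + sliding-window masking above (illustrative sizes)
sliding_window = 3
target_length = 6
cache_position = torch.arange(target_length)  # prefill: query positions 0..5

# True where a key position must be masked out for each query row
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)  # future keys
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
masked = diagonal_attend_mask | sliding_attend_mask

print(masked.int())
# Query row i attends only to keys j with i - 3 < j <= i; everything else is masked (1):
# tensor([[0, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1],
#         [1, 0, 0, 0, 1, 1],
#         [1, 1, 0, 0, 0, 1],
#         [1, 1, 1, 0, 0, 0]])
```

In the real method this boolean pattern is multiplied into a float mask filled with the dtype minimum and then merged with the 2D padding mask.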
+
+
+class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Starcoder2Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[KwargsForCausalLM],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM
+
+        >>> model = Starcoder2ForCausalLM.from_pretrained("bigcode/starcoder2-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Starcoder2 Model transformer with a sequence classification head on top (linear layer). + + [`Starcoder2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + STARCODER2_START_DOCSTRING, +) +class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Starcoder2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + STARCODER2_START_DOCSTRING, +) +class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Starcoder2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Starcoder2ForCausalLM", + "Starcoder2Model", + "Starcoder2PreTrainedModel", + "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/starcoder2/modular_starcoder2.py b/docs/transformers/src/transformers/models/starcoder2/modular_starcoder2.py new file mode 100644 index 0000000000000000000000000000000000000000..77b05d1dc1aa48e386104cb80ca6d4b4845d6cdf --- /dev/null +++ b/docs/transformers/src/transformers/models/starcoder2/modular_starcoder2.py @@ -0,0 +1,294 @@ +# coding=utf-8 +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Starcoder2 model."""
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
+from ..mistral.modeling_mistral import (
+    MistralAttention,
+    MistralDecoderLayer,
+    MistralForCausalLM,
+    MistralForSequenceClassification,
+    MistralForTokenClassification,
+    MistralModel,
+    MistralPreTrainedModel,
+    MistralRotaryEmbedding,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
+from .configuration_starcoder2 import Starcoder2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Starcoder2Config"
+_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
+
+
+class Starcoder2MLP(nn.Module):
+    def __init__(self, config: Starcoder2Config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
+        self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
+        self.act = ACT2FN[config.hidden_act]
+        self.residual_dropout = config.residual_dropout
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
+        return hidden_states
+
+
+class Starcoder2Attention(MistralAttention):
+    def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
+        super().__init__(config, layer_idx)
+        self.residual_dropout = config.residual_dropout
+        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
+        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
+        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
+        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(self.config, "sliding_window", None),  # diff with Llama
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        attn_output = nn.functional.dropout(
+            attn_output, p=self.residual_dropout, training=self.training
+        )  # diff with Llama
+
+        return attn_output, attn_weights
+
+
+class Starcoder2DecoderLayer(MistralDecoderLayer):
+    def __init__(self, config: Starcoder2Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
+        self.mlp = Starcoder2MLP(config)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+
+
+class Starcoder2RotaryEmbedding(MistralRotaryEmbedding):
+    pass
+
+
+class Starcoder2PreTrainedModel(MistralPreTrainedModel):
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
+
+
+STARCODER2_INPUTS_DOCSTRING = None  # will be automatically redefined
+
+
+class Starcoder2Model(MistralModel):
+    def __init__(self, config: Starcoder2Config):
+        super().__init__(config)
+        self.layers = nn.ModuleList(
+            [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+        self.embedding_dropout = config.embedding_dropout
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+
inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout( + hidden_states, p=self.embedding_dropout, training=self.training + ) # main diff with Llama + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Starcoder2ForCausalLM(MistralForCausalLM): + pass + + +class Starcoder2ForSequenceClassification(MistralForSequenceClassification): + pass + + +class Starcoder2ForTokenClassification(MistralForTokenClassification): + pass + + +__all__ = [ + "Starcoder2ForCausalLM", + "Starcoder2Model", + "Starcoder2PreTrainedModel", # noqa: F822 + "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/superglue/__init__.py b/docs/transformers/src/transformers/models/superglue/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..666f44026eff6e9c5d908370f7d13f1adfd1d1cf --- /dev/null +++ b/docs/transformers/src/transformers/models/superglue/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_superglue import * + from .image_processing_superglue import * + from .modeling_superglue import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/superglue/configuration_superglue.py b/docs/transformers/src/transformers/models/superglue/configuration_superglue.py new file mode 100644 index 0000000000000000000000000000000000000000..fe301442d632bc935017d85cbed95c574b15b62c --- /dev/null +++ b/docs/transformers/src/transformers/models/superglue/configuration_superglue.py @@ -0,0 +1,120 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +if TYPE_CHECKING: + from ..superpoint import SuperPointConfig + +logger = logging.get_logger(__name__) + + +class SuperGlueConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SuperGlueModel`]. It is used to instantiate a + SuperGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SuperGlue + [magic-leap-community/superglue_indoor](https://huggingface.co/magic-leap-community/superglue_indoor) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + hidden_size (`int`, *optional*, defaults to 256): + The dimension of the descriptors. 
+        keypoint_encoder_sizes (`List[int]`, *optional*, defaults to `[32, 64, 128, 256]`):
+            The sizes of the keypoint encoder layers.
+        gnn_layers_types (`List[str]`, *optional*, defaults to `['self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross']`):
+            The types of the GNN layers. Must be either 'self' or 'cross'.
+        num_attention_heads (`int`, *optional*, defaults to 4):
+            The number of heads in the GNN layers.
+        sinkhorn_iterations (`int`, *optional*, defaults to 100):
+            The number of Sinkhorn iterations.
+        matching_threshold (`float`, *optional*, defaults to 0.0):
+            The matching threshold.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Examples:
+        ```python
+        >>> from transformers import SuperGlueConfig, SuperGlueModel
+
+        >>> # Initializing a SuperGlue superglue style configuration
+        >>> configuration = SuperGlueConfig()
+
+        >>> # Initializing a model from the superglue style configuration
+        >>> model = SuperGlueModel(configuration)
+
+        >>> # Accessing the model configuration
+        >>> configuration = model.config
+        ```
+    """
+
+    model_type = "superglue"
+
+    def __init__(
+        self,
+        keypoint_detector_config: "SuperPointConfig" = None,
+        hidden_size: int = 256,
+        keypoint_encoder_sizes: List[int] = None,
+        gnn_layers_types: List[str] = None,
+        num_attention_heads: int = 4,
+        sinkhorn_iterations: int = 100,
+        matching_threshold: float = 0.0,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        self.gnn_layers_types = gnn_layers_types if gnn_layers_types is not None else ["self", "cross"] * 9
+        # Check whether all gnn_layers_types are either 'self' or 'cross'
+        if not all(layer_type in ["self", "cross"] for layer_type in self.gnn_layers_types):
+            raise ValueError("All gnn_layers_types must be either 'self' or 'cross'")
+
+        if hidden_size % num_attention_heads != 0:
+            raise ValueError("hidden_size must be divisible by num_attention_heads")
+
+        self.keypoint_encoder_sizes = (
+            keypoint_encoder_sizes if keypoint_encoder_sizes is not None else [32, 64, 128, 256]
+        )
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.sinkhorn_iterations = sinkhorn_iterations
+        self.matching_threshold = matching_threshold
+
+        if isinstance(keypoint_detector_config, dict):
+            keypoint_detector_config["model_type"] = (
+                keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint"
+            )
+            keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
+                **keypoint_detector_config
+            )
+        if keypoint_detector_config is None:
+            keypoint_detector_config = CONFIG_MAPPING["superpoint"]()
+
+        self.keypoint_detector_config = keypoint_detector_config
+        self.initializer_range = initializer_range
+        self.attention_probs_dropout_prob = 0
+        self.is_decoder = False
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["SuperGlueConfig"]
diff --git a/docs/transformers/src/transformers/models/superglue/convert_superglue_to_hf.py b/docs/transformers/src/transformers/models/superglue/convert_superglue_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfff39acdfd8ee457e72e35bed98c6e7bcebafea
--- /dev/null
+++ b/docs/transformers/src/transformers/models/superglue/convert_superglue_to_hf.py
@@ -0,0 +1,342 @@
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import gc
+import os
+import re
+from typing import List
+
+import torch
+from datasets import load_dataset
+
+from transformers import (
+    AutoModelForKeypointDetection,
+    SuperGlueConfig,
+    SuperGlueForKeypointMatching,
+    SuperGlueImageProcessor,
+)
+
+
+def prepare_imgs():
+    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
+    image1 = dataset[0]["image"]
+    image2 = dataset[1]["image"]
+    image3 = dataset[2]["image"]
+    return [[image1, image2], [image3, image2]]
+
+
+def verify_model_outputs(model, model_name, device):
+    images = prepare_imgs()
+    preprocessor = SuperGlueImageProcessor()
+    inputs = preprocessor(images=images, return_tensors="pt").to(device)
+    model.to(device)
+    with torch.no_grad():
+        outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+
+    predicted_matches_values = outputs.matches[0, 0, :10]
+    predicted_matching_scores_values = outputs.matching_scores[0, 0, :10]
+
+    predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
+
+    if "outdoor" in model_name:
+        expected_max_number_keypoints = 865
+        expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
+        expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
+
+        expected_matches_values = torch.tensor(
+            [125, 630, 137, 138, 136, 143, 135, -1, -1, 153], dtype=torch.int64, device=device
+        )
+        expected_matching_scores_values = torch.tensor(
+            [0.9899, 0.0033, 0.9897, 0.9889, 0.9879, 0.7464, 0.7109, 0, 0, 0.9841], device=device
+        )
+
+        expected_number_of_matches = 281
+    elif "indoor" in model_name:
+        expected_max_number_keypoints = 865
+        expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
+        expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
+
+        expected_matches_values = torch.tensor(
+            [125, 144, 137, 138, 136, 155, 135, -1, -1, 153], dtype=torch.int64, device=device
+        )
+        expected_matching_scores_values = torch.tensor(
+            [0.9694, 0.0010, 0.9006, 0.8753, 0.8521, 0.5688, 0.6321, 0.0, 0.0, 0.7235], device=device
+        )
+
+        expected_number_of_matches = 282
+    else:
+        # Guard against silently reading undefined expected values for an unknown checkpoint.
+        raise ValueError(f"Unknown model name: {model_name}. Expected a name containing 'outdoor' or 'indoor'.")
+
+    assert outputs.matches.shape == expected_matches_shape
+    assert outputs.matching_scores.shape == expected_matching_scores_shape
+
+    assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-4)
+    assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-4)
+
+    assert predicted_number_of_matches == expected_number_of_matches
+
+
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+    r"kenc.encoder.(\d+)": r"keypoint_encoder.encoder.\1.old",
+    r"gnn.layers.(\d+).attn.proj.0": r"gnn.layers.\1.attention.self.query",
+    r"gnn.layers.(\d+).attn.proj.1": r"gnn.layers.\1.attention.self.key",
+    r"gnn.layers.(\d+).attn.proj.2": r"gnn.layers.\1.attention.self.value",
+    r"gnn.layers.(\d+).attn.merge":
r"gnn.layers.\1.attention.output.dense", + r"gnn.layers.(\d+).mlp.0": r"gnn.layers.\1.mlp.0.linear", + r"gnn.layers.(\d+).mlp.1": r"gnn.layers.\1.mlp.0.batch_norm", + r"gnn.layers.(\d+).mlp.3": r"gnn.layers.\1.mlp.1", + r"final_proj": r"final_projection.final_proj", +} + + +def convert_old_keys_to_new_keys(state_dict_keys: List[str], conversion_mapping=ORIGINAL_TO_CONVERTED_KEY_MAPPING): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in conversion_mapping.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def replace_state_dict_keys(all_keys, new_keys, original_state_dict): + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + state_dict[new_key] = original_state_dict.pop(key).contiguous().clone() + return state_dict + + +def convert_state_dict(state_dict, config): + converted_to_final_key_mapping = {} + + def convert_conv_to_linear(keys): + for key in keys: + state_dict[key] = state_dict[key].squeeze(-1) + + def qkv_permute_weights_and_biases(keys, num_heads=4): + for key in keys: + tensor = state_dict[key] + shape = tensor.shape + dim_out = shape[0] + if len(shape) == 2: + dim_in = shape[1] + tensor = ( + tensor.reshape(dim_out // num_heads, num_heads, dim_in).permute(1, 0, 2).reshape(dim_out, dim_in) + ) + if len(shape) == 1: + tensor = tensor.reshape(dim_out // num_heads, num_heads).permute(1, 0).reshape(dim_out) + state_dict[key] = tensor + + def output_permute_weights(keys, num_heads=4): + for key in keys: + tensor = state_dict[key] + dim_in = tensor.shape[1] + dim_out = tensor.shape[0] + tensor = tensor.reshape(dim_out, dim_in // num_heads, num_heads).permute(0, 2, 1).reshape(dim_out, dim_in) + state_dict[key] = tensor + + conv_keys = [] + qkv_permute_keys = [] + output_permute_keys = [] + # Keypoint Encoder + keypoint_encoder_key = "keypoint_encoder.encoder" + for i in range(1, len(config.keypoint_encoder_sizes) + 2): + old_conv_key = f"{keypoint_encoder_key}.{(i - 1) * 3}.old" + new_index = i - 1 + new_conv_key = f"{keypoint_encoder_key}.{new_index}." + if i < len(config.keypoint_encoder_sizes) + 1: + new_conv_key = f"{new_conv_key}linear." + converted_to_final_key_mapping[rf"{old_conv_key}\."] = new_conv_key + if i < len(config.keypoint_encoder_sizes) + 1: + old_batch_norm_key = f"{keypoint_encoder_key}.{(i - 1) * 3 + 1}.old" + new_batch_norm_key = f"{keypoint_encoder_key}.{new_index}.batch_norm." 
+ converted_to_final_key_mapping[rf"{old_batch_norm_key}\."] = new_batch_norm_key + + conv_keys.append(f"{old_conv_key}.weight") + + # Attentional GNN + for i in range(len(config.gnn_layers_types)): + gnn_layer_key = f"gnn.layers.{i}" + ## Attention + attention_key = f"{gnn_layer_key}.attention" + conv_keys.extend( + [ + f"{attention_key}.self.query.weight", + f"{attention_key}.self.key.weight", + f"{attention_key}.self.value.weight", + f"{attention_key}.output.dense.weight", + ] + ) + qkv_permute_keys.extend( + [ + f"{attention_key}.self.query.weight", + f"{attention_key}.self.key.weight", + f"{attention_key}.self.value.weight", + f"{attention_key}.self.query.bias", + f"{attention_key}.self.key.bias", + f"{attention_key}.self.value.bias", + ] + ) + output_permute_keys.append(f"{attention_key}.output.dense.weight") + + ## MLP + conv_keys.extend([f"{gnn_layer_key}.mlp.0.linear.weight", f"{gnn_layer_key}.mlp.1.weight"]) + + # Final Projection + conv_keys.append("final_projection.final_proj.weight") + + convert_conv_to_linear(conv_keys) + qkv_permute_weights_and_biases(qkv_permute_keys) + output_permute_weights(output_permute_keys) + all_keys = list(state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys, converted_to_final_key_mapping) + state_dict = replace_state_dict_keys(all_keys, new_keys, state_dict) + return state_dict + + +def add_keypoint_detector_state_dict(superglue_state_dict): + keypoint_detector = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint") + keypoint_detector_state_dict = keypoint_detector.state_dict() + for k, v in keypoint_detector_state_dict.items(): + superglue_state_dict[f"keypoint_detector.{k}"] = v + return superglue_state_dict + + +@torch.no_grad() +def write_model( + model_path, + checkpoint_url, + safe_serialization=True, + push_to_hub=False, +): + os.makedirs(model_path, exist_ok=True) + + # ------------------------------------------------------------ + # SuperGlue config + # ------------------------------------------------------------ + + config = SuperGlueConfig( + hidden_size=256, + keypoint_encoder_sizes=[32, 64, 128, 256], + gnn_layers_types=["self", "cross"] * 9, + sinkhorn_iterations=100, + matching_threshold=0.0, + ) + config.architectures = ["SuperGlueForKeypointMatching"] + config.save_pretrained(model_path, push_to_hub=push_to_hub) + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + print(f"Fetching all parameters from the checkpoint at {checkpoint_url}...") + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url) + + print("Converting model...") + all_keys = list(original_state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + + state_dict = replace_state_dict_keys(all_keys, new_keys, original_state_dict) + state_dict = convert_state_dict(state_dict, config) + + del original_state_dict + gc.collect() + state_dict = add_keypoint_detector_state_dict(state_dict) + + print("Loading the checkpoint in a SuperGlue model...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with torch.device(device): + model = SuperGlueForKeypointMatching(config) + model.load_state_dict(state_dict, strict=True) + print("Checkpoint loaded successfully...") + del model.config._name_or_path + + print("Saving the model...") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + del state_dict, model + + # 
Safety check: reload the converted model
+    gc.collect()
+    print("Reloading the model to check if it's saved correctly.")
+    model = SuperGlueForKeypointMatching.from_pretrained(model_path)
+    print("Model reloaded successfully.")
+
+    model_name = "superglue"
+    if "superglue_outdoor.pth" in checkpoint_url:
+        model_name += "_outdoor"
+    elif "superglue_indoor.pth" in checkpoint_url:
+        model_name += "_indoor"
+
+    print("Checking the model outputs...")
+    verify_model_outputs(model, model_name, device)
+    print("Model outputs verified successfully.")
+
+    organization = "magic-leap-community"
+    if push_to_hub:
+        print("Pushing model to the hub...")
+        model.push_to_hub(
+            repo_id=f"{organization}/{model_name}",
+            commit_message="Add model",
+        )
+
+    write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub)
+
+
+def write_image_processor(save_dir, model_name, organization, push_to_hub=False):
+    image_processor = SuperGlueImageProcessor()
+    image_processor.save_pretrained(save_dir)
+
+    if push_to_hub:
+        print("Pushing image processor to the hub...")
+        image_processor.push_to_hub(
+            repo_id=f"{organization}/{model_name}",
+            commit_message="Add image processor",
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--checkpoint_url",
+        default="https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/weights/superglue_indoor.pth",
+        type=str,
+        help="URL of the original SuperGlue checkpoint you'd like to convert.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the output PyTorch model directory.",
+    )
+    parser.add_argument("--save_model", action="store_true", help="Save model to local")
+    parser.add_argument(
+        "--push_to_hub",
+        action="store_true",
+        help="Push model and image preprocessor to the hub",
+    )
+
+    args = parser.parse_args()
+    write_model(
+        args.pytorch_dump_folder_path, args.checkpoint_url, safe_serialization=True, push_to_hub=args.push_to_hub
+    )
diff --git a/docs/transformers/src/transformers/models/superglue/image_processing_superglue.py b/docs/transformers/src/transformers/models/superglue/image_processing_superglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..a74e1649961f485f3f5491c073678c383b193a12
--- /dev/null
+++ b/docs/transformers/src/transformers/models/superglue/image_processing_superglue.py
@@ -0,0 +1,409 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for SuperGlue."""
+
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ... import is_torch_available, is_vision_available
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    ImageType,
+    PILImageResampling,
+    get_image_type,
+    infer_channel_dimension_format,
+    is_pil_image,
+    is_scaled_image,
+    is_valid_image,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, logging, requires_backends
+from ...utils.import_utils import requires
+
+
+if is_torch_available():
+    import torch
+
+if TYPE_CHECKING:
+    from .modeling_superglue import KeypointMatchingOutput
+
+if is_vision_available():
+    import PIL
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.superpoint.image_processing_superpoint.is_grayscale
+def is_grayscale(
+    image: ImageInput,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+):
+    if input_data_format == ChannelDimension.FIRST:
+        if image.shape[0] == 1:
+            return True
+        return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
+    elif input_data_format == ChannelDimension.LAST:
+        if image.shape[-1] == 1:
+            return True
+        return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])
+
+
+# Copied from transformers.models.superpoint.image_processing_superpoint.convert_to_grayscale
+def convert_to_grayscale(
+    image: ImageInput,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> ImageInput:
+    """
+    Converts an image to grayscale format using the NTSC formula. Only supports numpy arrays and PIL images. TODO:
+    support torch and tensorflow grayscale conversion.
+
+    This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in
+    each channel, because of an issue that is discussed in:
+    https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
+
+    Args:
+        image (Image):
+            The image to convert.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the input image.
+    """
+    requires_backends(convert_to_grayscale, ["vision"])
+
+    if isinstance(image, np.ndarray):
+        if is_grayscale(image, input_data_format=input_data_format):
+            return image
+        if input_data_format == ChannelDimension.FIRST:
+            gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
+            gray_image = np.stack([gray_image] * 3, axis=0)
+        elif input_data_format == ChannelDimension.LAST:
+            gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
+            gray_image = np.stack([gray_image] * 3, axis=-1)
+        return gray_image
+
+    if not isinstance(image, PIL.Image.Image):
+        return image
+
+    image = image.convert("L")
+    return image
+
+
+def validate_and_format_image_pairs(images: ImageInput):
+    error_message = (
+        "Input images must be one of the following:\n"
+        " - A pair of PIL images.\n"
+        " - A pair of 3D arrays.\n"
+        " - A list of pairs of PIL images.\n"
+        " - A list of pairs of 3D arrays."
+    )
+
+    def _is_valid_image(image):
+        """Check that `image` is a PIL Image or a 3D array."""
+        return is_pil_image(image) or (
+            is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
+        )
+
+    if isinstance(images, list):
+        if len(images) == 2 and all((_is_valid_image(image)) for image in images):
+            return images
+        if all(
+            isinstance(image_pair, list)
+            and len(image_pair) == 2
+            and all(_is_valid_image(image) for image in image_pair)
+            for image_pair in images
+        ):
+            return [image for image_pair in images for image in image_pair]
+    raise ValueError(error_message)
+
+
+@requires(backends=("torch",))
+class SuperGlueImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a SuperGlue image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
+            overridden by `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]`, *optional*, defaults to `{"height": 480, "width": 640}`):
+            Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
+            `True`. Can be overridden by `size` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess`
+            method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale`
+            in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_grayscale (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess`
+            method.
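+
+    Example (a minimal, illustrative sketch; the dummy zero arrays below merely stand in for a real image pair):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import SuperGlueImageProcessor
+
+    >>> processor = SuperGlueImageProcessor()
+    >>> image1 = np.zeros((3, 100, 200), dtype=np.uint8)  # channels-first dummy image
+    >>> image2 = np.zeros((3, 100, 200), dtype=np.uint8)
+    >>> inputs = processor(images=[image1, image2], return_tensors="pt")
+    >>> inputs["pixel_values"].shape  # (batch_size, 2 images per pair, channels, height, width)
+    torch.Size([1, 2, 3, 480, 640])
+    ```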
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_grayscale: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 480, "width": 640} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale + + # Copied from transformers.models.superpoint.image_processing_superpoint.SuperPointImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be inferred from the input + image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_grayscale: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image pairs to preprocess. Expects either a list of 2 images or a list of list of 2 images list with + pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set + `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image + is resized to `(size["shortest_edge"], size["shortest_edge"])`. 
Otherwise, the smaller edge of the + image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + # Validate and convert the input images into a flattened list of images for all subsequent processing steps. + images = validate_and_format_image_pairs(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        all_images = []
+        for image in images:
+            if do_resize:
+                image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+            if do_rescale:
+                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+            if do_grayscale:
+                image = convert_to_grayscale(image, input_data_format=input_data_format)
+
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+            all_images.append(image)
+
+        # Convert back the flattened list of images into a list of pairs of images.
+        image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)]
+
+        data = {"pixel_values": image_pairs}
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def post_process_keypoint_matching(
+        self,
+        outputs: "KeypointMatchingOutput",
+        target_sizes: Union[TensorType, List[Tuple]],
+        threshold: float = 0.0,
+    ) -> List[Dict[str, torch.Tensor]]:
+        """
+        Converts the raw output of [`KeypointMatchingOutput`] into lists of matched keypoints with coordinates
+        absolute to the original image sizes, together with their matching scores.
+
+        Args:
+            outputs ([`KeypointMatchingOutput`]):
+                Raw outputs of the model.
+            target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`):
+                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
+                target size `(height, width)` of each image in the batch. This must be the original image size
+                (before any processing).
+            threshold (`float`, *optional*, defaults to 0.0):
+                Threshold to filter out the matches with low scores.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each containing the matched keypoints of the first and second
+            image of the pair and their matching scores.
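+
+        Example (a hypothetical sketch; assumes `processor`, `model`, preprocessed `inputs` and the original PIL
+        images `image1` and `image2` already exist):
+
+        ```python
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> # PIL's `.size` is (width, height); reverse it to get the (height, width) this method expects
+        >>> target_sizes = [[image1.size[::-1], image2.size[::-1]]]
+        >>> results = processor.post_process_keypoint_matching(outputs, target_sizes, threshold=0.2)
+        >>> sorted(results[0].keys())
+        ['keypoints0', 'keypoints1', 'matching_scores']
+        ```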
+ """ + if outputs.mask.shape[0] != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + if not all(len(target_size) == 2 for target_size in target_sizes): + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + if isinstance(target_sizes, List): + image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_pair_sizes = target_sizes + + keypoints = outputs.keypoints.clone() + keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2) + keypoints = keypoints.to(torch.int32) + + results = [] + for mask_pair, keypoints_pair, matches, scores in zip( + outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0] + ): + mask0 = mask_pair[0] > 0 + mask1 = mask_pair[1] > 0 + keypoints0 = keypoints_pair[0][mask0] + keypoints1 = keypoints_pair[1][mask1] + matches0 = matches[mask0] + scores0 = scores[mask0] + + # Filter out matches with low scores + valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1) + + matched_keypoints0 = keypoints0[valid_matches] + matched_keypoints1 = keypoints1[matches0[valid_matches]] + matching_scores = scores0[valid_matches] + + results.append( + { + "keypoints0": matched_keypoints0, + "keypoints1": matched_keypoints1, + "matching_scores": matching_scores, + } + ) + + return results + + +__all__ = ["SuperGlueImageProcessor"] diff --git a/docs/transformers/src/transformers/models/superglue/modeling_superglue.py b/docs/transformers/src/transformers/models/superglue/modeling_superglue.py new file mode 100644 index 0000000000000000000000000000000000000000..049eb91b8451f3d6e59974cd8daef76aa78c486a --- /dev/null +++ b/docs/transformers/src/transformers/models/superglue/modeling_superglue.py @@ -0,0 +1,866 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SuperGlue model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers import PreTrainedModel, add_start_docstrings +from transformers.models.superglue.configuration_superglue import SuperGlueConfig + +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging +from ..auto import AutoModelForKeypointDetection + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC_ = "SuperGlueConfig" +_CHECKPOINT_FOR_DOC_ = "magic-leap-community/superglue_indoor" + + +def concat_pairs(tensor_tuple0: Tuple[torch.Tensor], tensor_tuple1: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + """ + Concatenate two tuples of tensors pairwise + + Args: + tensor_tuple0 (`Tuple[torch.Tensor]`): + Tuple of tensors. + tensor_tuple1 (`Tuple[torch.Tensor]`): + Tuple of tensors. + + Returns: + (`Tuple[torch.Tensor]`): Tuple of concatenated tensors. + """ + return tuple([torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1)]) + + +def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Normalize keypoints locations based on image image_shape + + Args: + keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Keypoints locations in (x, y) format. + height (`int`): + Image height. + width (`int`): + Image width. + + Returns: + Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). + """ + size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (keypoints - center[:, None, :]) / scaling[:, None, :] + + +def log_sinkhorn_iterations( + log_cost_matrix: torch.Tensor, + log_source_distribution: torch.Tensor, + log_target_distribution: torch.Tensor, + num_iterations: int, +) -> torch.Tensor: + """ + Perform Sinkhorn Normalization in Log-space for stability + + Args: + log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): + Logarithm of the cost matrix. + log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`): + Logarithm of the source distribution. + log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`): + Logarithm of the target distribution. + + Returns: + log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal + transport matrix. + """ + log_u_scaling = torch.zeros_like(log_source_distribution) + log_v_scaling = torch.zeros_like(log_target_distribution) + for _ in range(num_iterations): + log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2) + log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1) + return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1) + + +def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor: + """ + Perform Differentiable Optimal Transport in Log-space for stability + + Args: + scores: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): + Cost matrix. + reg_param: (`torch.Tensor` of shape `(batch_size, 1, 1)`): + Regularization parameter. + iterations: (`int`): + Number of Sinkhorn iterations. 
+ + Returns: + log_optimal_transport_matrix: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the + optimal transport matrix. + """ + batch_size, num_rows, num_columns = scores.shape + one_tensor = scores.new_tensor(1) + num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores) + + source_reg_param = reg_param.expand(batch_size, num_rows, 1) + target_reg_param = reg_param.expand(batch_size, 1, num_columns) + reg_param = reg_param.expand(batch_size, 1, 1) + + couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1) + + log_normalization = -(num_rows_tensor + num_columns_tensor).log() + log_source_distribution = torch.cat( + [log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization] + ) + log_target_distribution = torch.cat( + [log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization] + ) + log_source_distribution, log_target_distribution = ( + log_source_distribution[None].expand(batch_size, -1), + log_target_distribution[None].expand(batch_size, -1), + ) + + log_optimal_transport_matrix = log_sinkhorn_iterations( + couplings, log_source_distribution, log_target_distribution, num_iterations=iterations + ) + log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization # multiply probabilities by M+N + return log_optimal_transport_matrix + + +def arange_like(x, dim: int) -> torch.Tensor: + return x.new_ones(x.shape[dim]).cumsum(0) - 1 + + +@dataclass +class KeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number + of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of + images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is + used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching + information. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches and matching_scores are keypoint matching information. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. 
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)`, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`) + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + mask: Optional[torch.IntTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SuperGlueMultiLayerPerceptron(nn.Module): + def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None: + super().__init__() + self.linear = nn.Linear(in_channels, out_channels) + self.batch_norm = nn.BatchNorm1d(out_channels) + self.activation = nn.ReLU() + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.linear(hidden_state) + hidden_state = hidden_state.transpose(-1, -2) + hidden_state = self.batch_norm(hidden_state) + hidden_state = hidden_state.transpose(-1, -2) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class SuperGlueKeypointEncoder(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + layer_sizes = config.keypoint_encoder_sizes + hidden_size = config.hidden_size + # 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint + encoder_channels = [3] + layer_sizes + [hidden_size] + + layers = [ + SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i]) + for i in range(1, len(encoder_channels) - 1) + ] + layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1])) + self.encoder = nn.ModuleList(layers) + + def forward( + self, + keypoints: torch.Tensor, + scores: torch.Tensor, + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + scores = scores.unsqueeze(2) + hidden_state = torch.cat([keypoints, scores], dim=2) + all_hidden_states = () if output_hidden_states else None + for layer in self.encoder: + hidden_state = layer(hidden_state) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + return hidden_state, all_hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->SuperGlue +class SuperGlueSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + 
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
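+        # query_layer and key_layer both have shape (batch_size, num_heads, seq_length, head_size), so the
+        # matmul below yields raw scores of shape (batch_size, num_heads, query_length, key_length).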
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class SuperGlueSelfOutput(nn.Module): + def __init__(self, config: SuperGlueConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + return hidden_states + + +SUPERGLUE_SELF_ATTENTION_CLASSES = { + "eager": SuperGlueSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->SuperGlue,BERT->SUPERGLUE +class SuperGlueAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = SuperGlueSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SuperGlueAttentionalPropagation(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + hidden_size = config.hidden_size + self.attention = SuperGlueAttention(config) + mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size] + layers = [ + SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i]) + for i in range(1, len(mlp_channels) - 1) + ] + layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1])) + self.mlp = nn.ModuleList(layers) + + def forward( + self, + descriptors: torch.Tensor, + attention_mask: 
Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]: + attention_outputs = self.attention( + descriptors, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + output = attention_outputs[0] + attention = attention_outputs[1:] + + hidden_state = torch.cat([descriptors, output], dim=2) + + all_hidden_states = () if output_hidden_states else None + for layer in self.mlp: + hidden_state = layer(hidden_state) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + return hidden_state, all_hidden_states, attention + + +class SuperGlueAttentionalGNN(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.layers_types = config.gnn_layers_types + self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))]) + + def forward( + self, + descriptors: torch.Tensor, + mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[Tuple]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + batch_size, num_keypoints, _ = descriptors.shape + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + for gnn_layer, layer_type in zip(self.layers, self.layers_types): + encoder_hidden_states = None + encoder_attention_mask = None + if layer_type == "cross": + encoder_hidden_states = ( + descriptors.reshape(-1, 2, num_keypoints, self.hidden_size) + .flip(1) + .reshape(batch_size, num_keypoints, self.hidden_size) + ) + encoder_attention_mask = ( + mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if mask is not None + else None + ) + + gnn_outputs = gnn_layer( + descriptors, + attention_mask=mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + delta = gnn_outputs[0] + + if output_hidden_states: + all_hidden_states = all_hidden_states + gnn_outputs[1] + if output_attentions: + all_attentions = all_attentions + gnn_outputs[2] + + descriptors = descriptors + delta + return descriptors, all_hidden_states, all_attentions + + +class SuperGlueFinalProjection(nn.Module): + def __init__(self, config: SuperGlueConfig) -> None: + super().__init__() + hidden_size = config.hidden_size + self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + return self.final_proj(descriptors) + + +class SuperGluePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SuperGlueConfig + base_model_prefix = "superglue" + main_input_name = "pixel_values" + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, SuperGlueMultiLayerPerceptron): + nn.init.constant_(module.linear.bias, 0.0) + + +SUPERGLUE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SuperGlueConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + +SUPERGLUE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SuperGlueImageProcessor`]. See + [`SuperGlueImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "SuperGlue model taking images as inputs and outputting the matching of them.", + SUPERGLUE_START_DOCSTRING, +) +class SuperGlueForKeypointMatching(SuperGluePreTrainedModel): + """SuperGlue feature matching middle-end + + Given two sets of keypoints and locations, we determine the + correspondences by: + 1. Keypoint Encoding (normalization + visual feature and location fusion) + 2. Graph Neural Network with multiple self and cross-attention layers + 3. Final projection layer + 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) + 5. Thresholding matrix based on mutual exclusivity and a match_threshold + + The correspondence ids use -1 to indicate non-matching points. + + Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural + Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763
+    """
+
+    def __init__(self, config: SuperGlueConfig) -> None:
+        super().__init__(config)
+
+        self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
+
+        self.keypoint_encoder = SuperGlueKeypointEncoder(config)
+        self.gnn = SuperGlueAttentionalGNN(config)
+        self.final_projection = SuperGlueFinalProjection(config)
+
+        bin_score = torch.nn.Parameter(torch.tensor(1.0))
+        self.register_parameter("bin_score", bin_score)
+
+        self.post_init()
+
+    def _match_image_pair(
+        self,
+        keypoints: torch.Tensor,
+        descriptors: torch.Tensor,
+        scores: torch.Tensor,
+        height: int,
+        width: int,
+        mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple, Tuple]:
+        """
+        Perform keypoint matching between two images.
+
+        Args:
+            keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
+                Keypoints detected in the image pair.
+            descriptors (`torch.Tensor` of shape `(batch_size, 2, descriptor_dim, num_keypoints)`):
+                Descriptors of the keypoints detected in the image pair.
+            scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
+                Confidence scores of the keypoints detected in the image pair.
+            height (`int`): Image height.
+            width (`int`): Image width.
+            mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
+                Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint
+                matching information.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors. Defaults to `config.output_attentions`.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. Defaults to `config.output_hidden_states`.
+
+        Returns:
+            matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
+                For each image pair, the index of the keypoint in image1 that each keypoint in image0 was matched
+                with, and vice versa for each keypoint in image1.
+            matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
+                Scores of predicted matches for each image pair.
+            all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+                Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(1, 2, num_keypoints,
+                num_channels)`.
+            all_attentions (`tuple(torch.FloatTensor)`, *optional*):
+                Tuple of `torch.FloatTensor` (one for each layer) of shape `(1, 2, num_heads, num_keypoints,
+                num_keypoints)`.
+ """ + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + batch_size, _, num_keypoints, _ = keypoints.shape + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2) + descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size) + scores = scores.reshape(batch_size * 2, num_keypoints) + mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None + + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states) + + last_hidden_state = encoded_keypoints[0] + + # Keypoint MLP encoder. + descriptors = descriptors + last_hidden_state + + if mask is not None: + input_shape = descriptors.size() + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, num_keypoints), device=keypoints.device) + + # Multi-layer Transformer network. + gnn_outputs = self.gnn( + descriptors, + mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors = gnn_outputs[0] + + # Final MLP projection. + projected_descriptors = self.final_projection(descriptors) + + # (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size) + final_descriptors0 = final_descriptors[:, 0] + final_descriptors1 = final_descriptors[:, 1] + + # Compute matching descriptor distance. + scores = final_descriptors0 @ final_descriptors1.transpose(1, 2) + scores = scores / self.config.hidden_size**0.5 + + if mask is not None: + mask = mask.reshape(batch_size, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints) + scores = scores.masked_fill(mask0 == 0, -1e9) + + # Run the optimal transport. + scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations) + + # Get the matches with score above "match_threshold". 
+        max0 = scores[:, :-1, :-1].max(2)
+        max1 = scores[:, :-1, :-1].max(1)
+        indices0 = max0.indices
+        indices1 = max1.indices
+        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+        zero = scores.new_tensor(0)
+        matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
+        matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
+        matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
+        valid0 = mutual0 & (matching_scores0 > zero)
+        valid1 = mutual1 & valid0.gather(1, indices1)
+        matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+        matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+        matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1)
+        matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + encoded_keypoints[1]
+            all_hidden_states = all_hidden_states + gnn_outputs[1]
+            all_hidden_states = all_hidden_states + (projected_descriptors,)
+            all_hidden_states = tuple(
+                x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
+            )
+        if output_attentions:
+            all_attentions = all_attentions + gnn_outputs[2]
+            all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)
+
+        return (
+            matches,
+            matching_scores,
+            all_hidden_states,
+            all_attentions,
+        )
+
+    @add_start_docstrings_to_model_forward(SUPERGLUE_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, KeypointMatchingOutput]:
+        """
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoModel
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
+        >>> image1 = Image.open(requests.get(url, stream=True).raw)
+        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
+        >>> image2 = Image.open(requests.get(url, stream=True).raw)
+        >>> images = [image1, image2]
+
+        >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
+        >>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
+
+        >>> with torch.no_grad():
+        ...     inputs = processor(images, return_tensors="pt")
+        ...     outputs = model(**inputs)
+        ```"""
+        loss = None
+        if labels is not None:
+            raise ValueError("SuperGlue is not trainable, no labels should be provided.")
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
+            raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
+
+        batch_size, _, channels, height, width = pixel_values.shape
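The block at the top of this hunk is the core of SuperGlue's match selection: after the optimal transport step, a pair is kept only when each keypoint is the other's best candidate and the score clears `matching_threshold`. A minimal, self-contained sketch of that mutual-nearest-neighbour check (the `mutual_matches` helper, the threshold value, and the toy score matrix are illustrative, not part of the model):

```python
import torch


def mutual_matches(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """scores: (num_keypoints0, num_keypoints1) similarity matrix, dustbin already removed."""
    max0 = scores.max(dim=1)  # best image-1 candidate for every image-0 keypoint
    max1 = scores.max(dim=0)  # best image-0 candidate for every image-1 keypoint
    # Keep pair (i, j) only if i -> j and j -> i agree (mutual check) ...
    mutual0 = torch.arange(scores.shape[0]) == max1.indices[max0.indices]
    # ... and its score clears the matching threshold; -1 marks "no match".
    valid0 = mutual0 & (max0.values > threshold)
    return torch.where(valid0, max0.indices, torch.full_like(max0.indices, -1))


scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.4, 0.5]])
print(mutual_matches(scores, threshold=0.6))  # tensor([ 0,  1, -1]): keypoint 2 has no mutual match
```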
pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, scores, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + scores = scores.reshape(batch_size, 2, -1).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + scores, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple( + v + for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions] + if v is not None + ) + + return KeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"] diff --git a/docs/transformers/src/transformers/models/superpoint/__init__.py b/docs/transformers/src/transformers/models/superpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aab40abaa86d5a0dffb3252c659b15e3b650485b --- /dev/null +++ b/docs/transformers/src/transformers/models/superpoint/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_superpoint import * + from .image_processing_superpoint import * + from .modeling_superpoint import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/superpoint/configuration_superpoint.py b/docs/transformers/src/transformers/models/superpoint/configuration_superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..82104e682bcbd2e4a4cf547e037c17d0b2992ad6 --- /dev/null +++ b/docs/transformers/src/transformers/models/superpoint/configuration_superpoint.py @@ -0,0 +1,90 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SuperPointConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a + SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SuperPoint + [magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + encoder_hidden_sizes (`List`, *optional*, defaults to `[64, 64, 128, 128]`): + The number of channels in each convolutional layer in the encoder. + decoder_hidden_size (`int`, *optional*, defaults to 256): The hidden size of the decoder. + keypoint_decoder_dim (`int`, *optional*, defaults to 65): The output dimension of the keypoint decoder. + descriptor_decoder_dim (`int`, *optional*, defaults to 256): The output dimension of the descriptor decoder. + keypoint_threshold (`float`, *optional*, defaults to 0.005): + The threshold to use for extracting keypoints. + max_keypoints (`int`, *optional*, defaults to -1): + The maximum number of keypoints to extract. If `-1`, will extract all keypoints. + nms_radius (`int`, *optional*, defaults to 4): + The radius for non-maximum suppression. + border_removal_distance (`int`, *optional*, defaults to 4): + The distance from the border to remove keypoints. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + ```python + >>> from transformers import SuperPointConfig, SuperPointForKeypointDetection + + >>> # Initializing a SuperPoint superpoint style configuration + >>> configuration = SuperPointConfig() + >>> # Initializing a model from the superpoint style configuration + >>> model = SuperPointForKeypointDetection(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "superpoint" + + def __init__( + self, + encoder_hidden_sizes: List[int] = [64, 64, 128, 128], + decoder_hidden_size: int = 256, + keypoint_decoder_dim: int = 65, + descriptor_decoder_dim: int = 256, + keypoint_threshold: float = 0.005, + max_keypoints: int = -1, + nms_radius: int = 4, + border_removal_distance: int = 4, + initializer_range=0.02, + **kwargs, + ): + self.encoder_hidden_sizes = encoder_hidden_sizes + self.decoder_hidden_size = decoder_hidden_size + self.keypoint_decoder_dim = keypoint_decoder_dim + self.descriptor_decoder_dim = descriptor_decoder_dim + self.keypoint_threshold = keypoint_threshold + self.max_keypoints = max_keypoints + self.nms_radius = nms_radius + self.border_removal_distance = border_removal_distance + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +__all__ = ["SuperPointConfig"] diff --git a/docs/transformers/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py b/docs/transformers/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..007966a0557a3f49048efb29c87c8897c50ea350 --- /dev/null +++ b/docs/transformers/src/transformers/models/superpoint/convert_superpoint_to_pytorch.py @@ -0,0 +1,175 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
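Since every detection threshold above is a plain constructor argument, a model variant only requires building a custom config. A small sketch (the hyperparameter values are illustrative, not recommended settings):

```python
from transformers import SuperPointConfig, SuperPointForKeypointDetection

# Illustrative values: keep at most 1024 keypoints per image and require a
# stricter detection score than the 0.005 default.
config = SuperPointConfig(max_keypoints=1024, keypoint_threshold=0.015)
model = SuperPointForKeypointDetection(config)  # randomly initialised weights
```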
+import argparse +import os + +import requests +import torch +from PIL import Image + +from transformers import SuperPointConfig, SuperPointForKeypointDetection, SuperPointImageProcessor + + +def get_superpoint_config(): + config = SuperPointConfig( + encoder_hidden_sizes=[64, 64, 128, 128], + decoder_hidden_size=256, + keypoint_decoder_dim=65, + descriptor_decoder_dim=256, + keypoint_threshold=0.005, + max_keypoints=-1, + nms_radius=4, + border_removal_distance=4, + initializer_range=0.02, + ) + + return config + + +def create_rename_keys(config, state_dict): + rename_keys = [] + + # Encoder weights + rename_keys.append(("conv1a.weight", "encoder.conv_blocks.0.conv_a.weight")) + rename_keys.append(("conv1b.weight", "encoder.conv_blocks.0.conv_b.weight")) + rename_keys.append(("conv2a.weight", "encoder.conv_blocks.1.conv_a.weight")) + rename_keys.append(("conv2b.weight", "encoder.conv_blocks.1.conv_b.weight")) + rename_keys.append(("conv3a.weight", "encoder.conv_blocks.2.conv_a.weight")) + rename_keys.append(("conv3b.weight", "encoder.conv_blocks.2.conv_b.weight")) + rename_keys.append(("conv4a.weight", "encoder.conv_blocks.3.conv_a.weight")) + rename_keys.append(("conv4b.weight", "encoder.conv_blocks.3.conv_b.weight")) + rename_keys.append(("conv1a.bias", "encoder.conv_blocks.0.conv_a.bias")) + rename_keys.append(("conv1b.bias", "encoder.conv_blocks.0.conv_b.bias")) + rename_keys.append(("conv2a.bias", "encoder.conv_blocks.1.conv_a.bias")) + rename_keys.append(("conv2b.bias", "encoder.conv_blocks.1.conv_b.bias")) + rename_keys.append(("conv3a.bias", "encoder.conv_blocks.2.conv_a.bias")) + rename_keys.append(("conv3b.bias", "encoder.conv_blocks.2.conv_b.bias")) + rename_keys.append(("conv4a.bias", "encoder.conv_blocks.3.conv_a.bias")) + rename_keys.append(("conv4b.bias", "encoder.conv_blocks.3.conv_b.bias")) + + # Keypoint Decoder weights + rename_keys.append(("convPa.weight", "keypoint_decoder.conv_score_a.weight")) + rename_keys.append(("convPb.weight", "keypoint_decoder.conv_score_b.weight")) + rename_keys.append(("convPa.bias", "keypoint_decoder.conv_score_a.bias")) + rename_keys.append(("convPb.bias", "keypoint_decoder.conv_score_b.bias")) + + # Descriptor Decoder weights + rename_keys.append(("convDa.weight", "descriptor_decoder.conv_descriptor_a.weight")) + rename_keys.append(("convDb.weight", "descriptor_decoder.conv_descriptor_b.weight")) + rename_keys.append(("convDa.bias", "descriptor_decoder.conv_descriptor_a.bias")) + rename_keys.append(("convDb.bias", "descriptor_decoder.conv_descriptor_b.bias")) + + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def prepare_imgs(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im1 = Image.open(requests.get(url, stream=True).raw) + url = "http://images.cocodataset.org/test-stuff2017/000000004016.jpg" + im2 = Image.open(requests.get(url, stream=True).raw) + return [im1, im2] + + +@torch.no_grad() +def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub, test_mode=False): + """ + Copy/paste/tweak model's weights to our SuperPoint structure. 
+ """ + + print("Downloading original model from checkpoint...") + config = get_superpoint_config() + + # load original state_dict from URL + original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url) + + print("Converting model parameters...") + # rename keys + rename_keys = create_rename_keys(config, original_state_dict) + new_state_dict = original_state_dict.copy() + for src, dest in rename_keys: + rename_key(new_state_dict, src, dest) + + # Load HuggingFace model + model = SuperPointForKeypointDetection(config) + model.load_state_dict(new_state_dict) + model.eval() + print("Successfully loaded weights in the model") + + # Check model outputs + preprocessor = SuperPointImageProcessor() + inputs = preprocessor(images=prepare_imgs(), return_tensors="pt") + outputs = model(**inputs) + + # If test_mode is True, we check that the model outputs match the original results + if test_mode: + torch.count_nonzero(outputs.mask[0]) + expected_keypoints_shape = (2, 830, 2) + expected_scores_shape = (2, 830) + expected_descriptors_shape = (2, 830, 256) + + expected_keypoints_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]) + expected_scores_values = torch.tensor([0.0064, 0.0140, 0.0595, 0.0728, 0.5170, 0.0175, 0.1523, 0.2055, 0.0336]) + expected_descriptors_value = torch.tensor(-0.1096) + assert outputs.keypoints.shape == expected_keypoints_shape + assert outputs.scores.shape == expected_scores_shape + assert outputs.descriptors.shape == expected_descriptors_shape + + assert torch.allclose(outputs.keypoints[0, :3], expected_keypoints_values, atol=1e-3) + assert torch.allclose(outputs.scores[0, :9], expected_scores_values, atol=1e-3) + assert torch.allclose(outputs.descriptors[0, 0, 0], expected_descriptors_value, atol=1e-3) + print("Model outputs match the original results!") + + if save_model: + print("Saving model to local...") + # Create folder to save model + if not os.path.isdir(pytorch_dump_folder_path): + os.mkdir(pytorch_dump_folder_path) + + model.save_pretrained(pytorch_dump_folder_path) + preprocessor.save_pretrained(pytorch_dump_folder_path) + + model_name = "magic-leap-community/superpoint" + if push_to_hub: + print(f"Pushing {model_name} to the hub...") + model.push_to_hub(model_name) + preprocessor.push_to_hub(model_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://github.com/magicleap/SuperPointPretrainedNetwork/raw/master/superpoint_v1.pth", + type=str, + help="URL of the original SuperPoint checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="model", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub") + + args = parser.parse_args() + convert_superpoint_checkpoint( + args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub + ) diff --git a/docs/transformers/src/transformers/models/superpoint/image_processing_superpoint.py b/docs/transformers/src/transformers/models/superpoint/image_processing_superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ad94613f090cceb50415597e39123f21569180 --- /dev/null +++ b/docs/transformers/src/transformers/models/superpoint/image_processing_superpoint.py @@ -0,0 +1,342 @@ +# Copyright 
2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for SuperPoint."""
+
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ... import is_torch_available, is_vision_available
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+)
+from ...utils import TensorType, logging, requires_backends
+
+
+if is_torch_available():
+    import torch
+
+if TYPE_CHECKING:
+    from .modeling_superpoint import SuperPointKeypointDescriptionOutput
+
+if is_vision_available():
+    import PIL
+
+logger = logging.get_logger(__name__)
+
+
+def is_grayscale(
+    image: ImageInput,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+):
+    if input_data_format == ChannelDimension.FIRST:
+        if image.shape[0] == 1:
+            return True
+        return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
+    elif input_data_format == ChannelDimension.LAST:
+        if image.shape[-1] == 1:
+            return True
+        return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])
+
+
+def convert_to_grayscale(
+    image: ImageInput,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> ImageInput:
+    """
+    Converts an image to grayscale format using the NTSC formula. Only supports numpy arrays and PIL images. TODO:
+    support torch and tensorflow grayscale conversion.
+
+    This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in
+    each channel, because of an issue that is discussed in:
+    https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
+
+    Args:
+        image (Image):
+            The image to convert.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the input image.
+    """
+    requires_backends(convert_to_grayscale, ["vision"])
+
+    if isinstance(image, np.ndarray):
+        if is_grayscale(image, input_data_format=input_data_format):
+            return image
+        if input_data_format == ChannelDimension.FIRST:
+            gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
+            gray_image = np.stack([gray_image] * 3, axis=0)
+        elif input_data_format == ChannelDimension.LAST:
+            gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
+            gray_image = np.stack([gray_image] * 3, axis=-1)
+        return gray_image
+
+    if not isinstance(image, PIL.Image.Image):
+        return image
+
+    image = image.convert("L")
+    return image
+
+
+class SuperPointImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a SuperPoint image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
+            overridden by `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]`, *optional*, defaults to `{"height": 480, "width": 640}`):
+            Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
+            `True`. Can be overridden by `size` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255,
+        do_grayscale: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 480, "width": 640}
+        size = get_size_dict(size, default_to_square=False)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_grayscale = do_grayscale
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ):
+        """
+        Resize an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the output image. If not provided, it will be inferred from the input
+                image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_grayscale: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image + is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the + image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to + `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ """ + + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None: + raise ValueError("Size must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [self.resize(image=image, size=size, input_data_format=input_data_format) for image in images] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_grayscale: + images = [convert_to_grayscale(image, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_detection( + self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, List[Tuple]] + ) -> List[Dict[str, "torch.Tensor"]]: + """ + Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + + Args: + outputs ([`SuperPointKeypointDescriptionOutput`]): + Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. This must be the original + image size (before any processing). + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according + to target_sizes, scores and descriptors for an image in the batch as predicted by the model. 
+ """ + if len(outputs.mask) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + + if isinstance(target_sizes, List): + image_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_sizes = target_sizes + + # Flip the image sizes to (width, height) and convert keypoints to absolute coordinates + image_sizes = torch.flip(image_sizes, [1]) + masked_keypoints = outputs.keypoints * image_sizes[:, None] + + # Convert masked_keypoints to int + masked_keypoints = masked_keypoints.to(torch.int32) + + results = [] + for image_mask, keypoints, scores, descriptors in zip( + outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors + ): + indices = torch.nonzero(image_mask).squeeze(1) + keypoints = keypoints[indices] + scores = scores[indices] + descriptors = descriptors[indices] + results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors}) + + return results + + +__all__ = ["SuperPointImageProcessor"] diff --git a/docs/transformers/src/transformers/models/superpoint/modeling_superpoint.py b/docs/transformers/src/transformers/models/superpoint/modeling_superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..219a0e52adfbd74522c1f20105f5adfd34ff75d9 --- /dev/null +++ b/docs/transformers/src/transformers/models/superpoint/modeling_superpoint.py @@ -0,0 +1,507 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch SuperPoint model.""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers import PreTrainedModel +from transformers.modeling_outputs import ( + BaseModelOutputWithNoAttention, +) +from transformers.models.superpoint.configuration_superpoint import SuperPointConfig + +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SuperPointConfig" + +_CHECKPOINT_FOR_DOC = "magic-leap-community/superpoint" + + +def remove_keypoints_from_borders( + keypoints: torch.Tensor, scores: torch.Tensor, border: int, height: int, width: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Removes keypoints (and their associated scores) that are too close to the border""" + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints: torch.Tensor, scores: torch.Tensor, k: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Keeps the k keypoints with highest score""" + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor: + """Applies non-maximum suppression on scores""" + if nms_radius < 0: + raise ValueError("Expected positive values for nms_radius") + + def max_pool(x): + return nn.functional.max_pool2d(x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +@dataclass +class SuperPointKeypointDescriptionOutput(ModelOutput): + """ + Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of + keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, + the maximum number of keypoints is set as the dimension of the keypoints, scores and descriptors tensors. The mask + tensor is used to indicate which values in the keypoints, scores and descriptors tensors are keypoint information + and which are padding. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Relative (x, y) coordinates of predicted keypoints in a given image. + scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`): + Scores of predicted keypoints. + descriptors (`torch.FloatTensor` of shape `(batch_size, num_keypoints, descriptor_size)`): + Descriptors of predicted keypoints. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in keypoints, scores and descriptors are keypoint information. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + """ + + loss: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.IntTensor] = None + scores: Optional[torch.FloatTensor] = None + descriptors: Optional[torch.FloatTensor] = None + mask: Optional[torch.BoolTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class SuperPointConvBlock(nn.Module): + def __init__( + self, config: SuperPointConfig, in_channels: int, out_channels: int, add_pooling: bool = False + ) -> None: + super().__init__() + self.conv_a = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + self.conv_b = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if add_pooling else None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.relu(self.conv_a(hidden_states)) + hidden_states = self.relu(self.conv_b(hidden_states)) + if self.pool is not None: + hidden_states = self.pool(hidden_states) + return hidden_states + + +class SuperPointEncoder(nn.Module): + """ + SuperPoint encoder module. It is made of 4 convolutional layers with ReLU activation and max pooling, reducing the + dimensionality of the image. + """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__() + # SuperPoint uses 1 channel images + self.input_dim = 1 + + conv_blocks = [] + conv_blocks.append( + SuperPointConvBlock(config, self.input_dim, config.encoder_hidden_sizes[0], add_pooling=True) + ) + for i in range(1, len(config.encoder_hidden_sizes) - 1): + conv_blocks.append( + SuperPointConvBlock( + config, config.encoder_hidden_sizes[i - 1], config.encoder_hidden_sizes[i], add_pooling=True + ) + ) + conv_blocks.append( + SuperPointConvBlock( + config, config.encoder_hidden_sizes[-2], config.encoder_hidden_sizes[-1], add_pooling=False + ) + ) + self.conv_blocks = nn.ModuleList(conv_blocks) + + def forward( + self, + input, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + all_hidden_states = () if output_hidden_states else None + + for conv_block in self.conv_blocks: + input = conv_block(input) + if output_hidden_states: + all_hidden_states = all_hidden_states + (input,) + output = input + if not return_dict: + return tuple(v for v in [output, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=output, + hidden_states=all_hidden_states, + ) + + +class SuperPointInterestPointDecoder(nn.Module): + """ + The SuperPointInterestPointDecoder uses the output of the SuperPointEncoder to compute the keypoint with scores. + The scores are first computed by a convolutional layer, then a softmax is applied to get a probability distribution + over the 65 possible keypoint classes. The keypoints are then extracted from the scores by thresholding and + non-maximum suppression. 
Post-processing is then applied to remove keypoints too close to the image borders as well + as to keep only the k keypoints with highest score. + """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__() + self.keypoint_threshold = config.keypoint_threshold + self.max_keypoints = config.max_keypoints + self.nms_radius = config.nms_radius + self.border_removal_distance = config.border_removal_distance + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv_score_a = nn.Conv2d( + config.encoder_hidden_sizes[-1], + config.decoder_hidden_size, + kernel_size=3, + stride=1, + padding=1, + ) + self.conv_score_b = nn.Conv2d( + config.decoder_hidden_size, config.keypoint_decoder_dim, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, encoded: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + scores = self._get_pixel_scores(encoded) + keypoints, scores = self._extract_keypoints(scores) + + return keypoints, scores + + def _get_pixel_scores(self, encoded: torch.Tensor) -> torch.Tensor: + """Based on the encoder output, compute the scores for each pixel of the image""" + scores = self.relu(self.conv_score_a(encoded)) + scores = self.conv_score_b(scores) + scores = nn.functional.softmax(scores, 1)[:, :-1] + batch_size, _, height, width = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(batch_size, height, width, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(batch_size, height * 8, width * 8) + scores = simple_nms(scores, self.nms_radius) + return scores + + def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation. + The keypoints are in the form of relative (x, y) coordinates. + """ + _, height, width = scores.shape + + # Threshold keypoints by score value + keypoints = torch.nonzero(scores[0] > self.keypoint_threshold) + scores = scores[0][tuple(keypoints.t())] + + # Discard keypoints near the image borders + keypoints, scores = remove_keypoints_from_borders( + keypoints, scores, self.border_removal_distance, height * 8, width * 8 + ) + + # Keep the k keypoints with highest score + if self.max_keypoints >= 0: + keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints) + + # Convert (y, x) to (x, y) + keypoints = torch.flip(keypoints, [1]).float() + + return keypoints, scores + + +class SuperPointDescriptorDecoder(nn.Module): + """ + The SuperPointDescriptorDecoder uses the outputs of both the SuperPointEncoder and the + SuperPointInterestPointDecoder to compute the descriptors at the keypoints locations. + + The descriptors are first computed by a convolutional layer, then normalized to have a norm of 1. The descriptors + are then interpolated at the keypoints locations. 
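The two permute/reshape steps in `_get_pixel_scores` above implement a depth-to-space unfold: after the dustbin channel is dropped, each of the 64 remaining channels of a coarse cell becomes one pixel of an 8x8 block in the full-resolution score map. A quick standalone shape check with toy sizes:

```python
import torch

# Toy sizes: a 2x3 grid of cells, 64 score channels per cell -> a 16x24 score map.
batch_size, height, width = 1, 2, 3
cell_scores = torch.rand(batch_size, 64, height, width)

scores = cell_scores.permute(0, 2, 3, 1).reshape(batch_size, height, width, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(batch_size, height * 8, width * 8)
print(scores.shape)  # torch.Size([1, 16, 24])
```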
+ """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__() + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv_descriptor_a = nn.Conv2d( + config.encoder_hidden_sizes[-1], + config.decoder_hidden_size, + kernel_size=3, + stride=1, + padding=1, + ) + self.conv_descriptor_b = nn.Conv2d( + config.decoder_hidden_size, + config.descriptor_decoder_dim, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, encoded: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor: + """Based on the encoder output and the keypoints, compute the descriptors for each keypoint""" + descriptors = self.conv_descriptor_b(self.relu(self.conv_descriptor_a(encoded))) + descriptors = nn.functional.normalize(descriptors, p=2, dim=1) + + descriptors = self._sample_descriptors(keypoints[None], descriptors[0][None], 8)[0] + + # [descriptor_dim, num_keypoints] -> [num_keypoints, descriptor_dim] + descriptors = torch.transpose(descriptors, 0, 1) + + return descriptors + + @staticmethod + def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor: + """Interpolate descriptors at keypoint locations""" + batch_size, num_channels, height, width = descriptors.shape + keypoints = keypoints - scale / 2 + 0.5 + divisor = torch.tensor([[(width * scale - scale / 2 - 0.5), (height * scale - scale / 2 - 0.5)]]) + divisor = divisor.to(keypoints) + keypoints /= divisor + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + kwargs = {"align_corners": True} + # [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2] + keypoints = keypoints.view(batch_size, 1, -1, 2) + descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs) + # [batch_size, descriptor_decoder_dim, num_channels, num_keypoints] -> [batch_size, descriptor_decoder_dim, num_keypoints] + descriptors = descriptors.reshape(batch_size, num_channels, -1) + descriptors = nn.functional.normalize(descriptors, p=2, dim=1) + return descriptors + + +class SuperPointPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SuperPointConfig + base_model_prefix = "superpoint" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: + """ + Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same, + extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for SuperPoint. 
+        a workaround for the issue discussed in:
+        https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
+
+        Args:
+            pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)
+
+        Returns:
+            pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
+
+        """
+        return pixel_values[:, 0, :, :][:, None, :, :]
+
+
+SUPERPOINT_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`SuperPointConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+    """
+
+SUPERPOINT_INPUTS_DOCSTRING = r"""
+Args:
+    pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        Pixel values. Pixel values can be obtained using [`SuperPointImageProcessor`]. See
+        [`SuperPointImageProcessor.__call__`] for details.
+    output_hidden_states (`bool`, *optional*):
+        Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more
+        detail.
+    return_dict (`bool`, *optional*):
+        Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+    """
+
+
+@add_start_docstrings(
+    "SuperPoint model outputting keypoints and descriptors.",
+    SUPERPOINT_START_DOCSTRING,
+)
+class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
+    """
+    SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
+    SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
+    Description <https://arxiv.org/abs/1712.07629>`__ by Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. It
+    is a fully convolutional neural network that extracts keypoints and descriptors from an image. It is trained in a
+    self-supervised manner, using a combination of a photometric loss and a loss based on the homographic adaptation
+    of keypoints. It is made of a convolutional encoder and two decoders: one for keypoints and one for descriptors.
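Because the number of detections varies per image, the `forward` pass below pads every image to `maximum_num_keypoints` and records the real slots in `mask`. A small illustrative sketch of stripping that padding on the consumer side (the `unpad_outputs` helper is hypothetical, not part of the model):

```python
import torch


def unpad_outputs(keypoints: torch.Tensor, scores: torch.Tensor, mask: torch.Tensor):
    """keypoints: (batch, max_kpts, 2); scores and mask: (batch, max_kpts)."""
    per_image = []
    for image_keypoints, image_scores, image_mask in zip(keypoints, scores, mask):
        valid = image_mask.bool()  # 1 where a real keypoint was detected
        per_image.append((image_keypoints[valid], image_scores[valid]))
    return per_image
```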
+ """ + + def __init__(self, config: SuperPointConfig) -> None: + super().__init__(config) + + self.config = config + + self.encoder = SuperPointEncoder(config) + self.keypoint_decoder = SuperPointInterestPointDecoder(config) + self.descriptor_decoder = SuperPointDescriptorDecoder(config) + + self.post_init() + + @add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SuperPointKeypointDescriptionOutput]: + """ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SuperPointForKeypointDetection + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint") + >>> model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint") + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + ```""" + loss = None + if labels is not None: + raise ValueError("SuperPoint does not support training for now.") + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + pixel_values = self.extract_one_channel_pixel_values(pixel_values) + + batch_size, _, height, width = pixel_values.shape + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + + list_keypoints_scores = [ + self.keypoint_decoder(last_hidden_state[None, ...]) for last_hidden_state in last_hidden_state + ] + + list_keypoints = [keypoints_scores[0] for keypoints_scores in list_keypoints_scores] + list_scores = [keypoints_scores[1] for keypoints_scores in list_keypoints_scores] + + list_descriptors = [ + self.descriptor_decoder(last_hidden_state[None, ...], keypoints[None, ...]) + for last_hidden_state, keypoints in zip(last_hidden_state, list_keypoints) + ] + + maximum_num_keypoints = max(keypoints.shape[0] for keypoints in list_keypoints) + + keypoints = torch.zeros((batch_size, maximum_num_keypoints, 2), device=pixel_values.device) + scores = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device) + descriptors = torch.zeros( + (batch_size, maximum_num_keypoints, self.config.descriptor_decoder_dim), + device=pixel_values.device, + ) + mask = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device, dtype=torch.int) + + for i, (_keypoints, _scores, _descriptors) in enumerate(zip(list_keypoints, list_scores, list_descriptors)): + keypoints[i, : _keypoints.shape[0]] = _keypoints + scores[i, : _scores.shape[0]] = _scores + descriptors[i, : _descriptors.shape[0]] = _descriptors + mask[i, : _scores.shape[0]] = 1 + + # Convert to relative coordinates + keypoints = keypoints / torch.tensor([width, height], device=keypoints.device) + + hidden_states = encoder_outputs[1] if output_hidden_states else None + if not return_dict: + return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None) + + return SuperPointKeypointDescriptionOutput( + loss=loss, + keypoints=keypoints, + 
scores=scores, + descriptors=descriptors, + mask=mask, + hidden_states=hidden_states, + ) + + +__all__ = ["SuperPointForKeypointDetection", "SuperPointPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/swiftformer/__init__.py b/docs/transformers/src/transformers/models/swiftformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..370f6c71fadb5dad5ccae5609ea384700dbd97b9 --- /dev/null +++ b/docs/transformers/src/transformers/models/swiftformer/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_swiftformer import * + from .modeling_swiftformer import * + from .modeling_tf_swiftformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/swiftformer/configuration_swiftformer.py b/docs/transformers/src/transformers/models/swiftformer/configuration_swiftformer.py new file mode 100644 index 0000000000000000000000000000000000000000..00a0aaddfa240b9141c9e245f845c87538bed542 --- /dev/null +++ b/docs/transformers/src/transformers/models/swiftformer/configuration_swiftformer.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2023 MBZUAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SwiftFormer model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SwiftFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SwiftFormerModel`]. It is used to instantiate an + SwiftFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SwiftFormer + [MBZUAI/swiftformer-xs](https://huggingface.co/MBZUAI/swiftformer-xs) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
+    Read the documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        depths (`List[int]`, *optional*, defaults to `[3, 3, 6, 4]`):
+            Depth of each stage.
+        embed_dims (`List[int]`, *optional*, defaults to `[48, 56, 112, 220]`):
+            The embedding dimension at each stage.
+        mlp_ratio (`int`, *optional*, defaults to 4):
+            Ratio of the hidden dimensionality of an MLP to the dimensionality of its input.
+        downsamples (`List[bool]`, *optional*, defaults to `[True, True, True, True]`):
+            Whether or not to downsample inputs between two stages.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (string). `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        down_patch_size (`int`, *optional*, defaults to 3):
+            The size of patches in downsampling layers.
+        down_stride (`int`, *optional*, defaults to 2):
+            The stride of convolution kernels in downsampling layers.
+        down_pad (`int`, *optional*, defaults to 1):
+            Padding in downsampling layers.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            The maximum drop probability used by DropPath (stochastic depth); the per-block rate increases linearly
+            up to this value over the depth of the network.
+        drop_mlp_rate (`float`, *optional*, defaults to 0.0):
+            Dropout rate for the MLP component of SwiftFormer.
+        drop_conv_encoder_rate (`float`, *optional*, defaults to 0.0):
+            Dropout rate for the ConvEncoder component of SwiftFormer.
+        use_layer_scale (`bool`, *optional*, defaults to `True`):
+            Whether to scale outputs from token mixers.
+        layer_scale_init_value (`float`, *optional*, defaults to 1e-05):
+            Factor by which outputs from token mixers are scaled.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the batch normalization layers.
+ + + Example: + + ```python + >>> from transformers import SwiftFormerConfig, SwiftFormerModel + + >>> # Initializing a SwiftFormer swiftformer-base-patch16-224 style configuration + >>> configuration = SwiftFormerConfig() + + >>> # Initializing a model (with random weights) from the swiftformer-base-patch16-224 style configuration + >>> model = SwiftFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "swiftformer" + + def __init__( + self, + image_size=224, + num_channels=3, + depths=[3, 3, 6, 4], + embed_dims=[48, 56, 112, 220], + mlp_ratio=4, + downsamples=[True, True, True, True], + hidden_act="gelu", + down_patch_size=3, + down_stride=2, + down_pad=1, + drop_path_rate=0.0, + drop_mlp_rate=0.0, + drop_conv_encoder_rate=0.0, + use_layer_scale=True, + layer_scale_init_value=1e-5, + batch_norm_eps=1e-5, + **kwargs, + ): + super().__init__(**kwargs) + self.image_size = image_size + self.num_channels = num_channels + self.depths = depths + self.embed_dims = embed_dims + self.mlp_ratio = mlp_ratio + self.downsamples = downsamples + self.hidden_act = hidden_act + self.down_patch_size = down_patch_size + self.down_stride = down_stride + self.down_pad = down_pad + self.drop_path_rate = drop_path_rate + self.drop_mlp_rate = drop_mlp_rate + self.drop_conv_encoder_rate = drop_conv_encoder_rate + self.use_layer_scale = use_layer_scale + self.layer_scale_init_value = layer_scale_init_value + self.batch_norm_eps = batch_norm_eps + + +class SwiftFormerOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["SwiftFormerConfig", "SwiftFormerOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/swiftformer/convert_swiftformer_original_to_hf.py b/docs/transformers/src/transformers/models/swiftformer/convert_swiftformer_original_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..3567bb674e950bc7015043647672f6dc3dc0fee2 --- /dev/null +++ b/docs/transformers/src/transformers/models/swiftformer/convert_swiftformer_original_to_hf.py @@ -0,0 +1,175 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
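+# Example invocation (a sketch; the checkpoint filename is illustrative, and `--original_ckpt` accepts
+# either a local path or an https URL to an original SwiftFormer checkpoint):
+#
+#   python convert_swiftformer_original_to_hf.py \
+#       --swiftformer_name swiftformer_xs \
+#       --pytorch_dump_folder_path ./converted_outputs/ \
+#       --original_ckpt ./swiftformer_xs.pth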
+"""Convert SwiftFormer checkpoints from the original implementation.""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + SwiftFormerConfig, + SwiftFormerForImageClassification, + ViTImageProcessor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +device = torch.device("cpu") + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +def get_expected_output(swiftformer_name): + if swiftformer_name == "swiftformer_xs": + return torch.tensor([-2.1703e00, 2.1107e00, -2.0811e00, 8.8685e-01, 2.4360e-01]) + + elif swiftformer_name == "swiftformer_s": + return torch.tensor([3.9636e-01, 2.3478e-01, -1.6963e00, -1.7381e00, -8.6337e-01]) + + elif swiftformer_name == "swiftformer_l1": + return torch.tensor([-4.2768e-01, -4.7429e-01, -1.0897e00, -1.0248e00, 3.5523e-02]) + + elif swiftformer_name == "swiftformer_l3": + return torch.tensor([-2.5330e-01, 2.4211e-01, -6.0185e-01, -8.2789e-01, -6.0446e-02]) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def create_rename_keys(state_dict): + rename_keys = [] + for k in state_dict.keys(): + k_new = k + if ".pwconv" in k: + k_new = k_new.replace(".pwconv", ".point_wise_conv") + if ".dwconv" in k: + k_new = k_new.replace(".dwconv", ".depth_wise_conv") + if ".Proj." in k: + k_new = k_new.replace(".Proj.", ".proj.") + if "patch_embed" in k_new: + k_new = k_new.replace("patch_embed", "swiftformer.patch_embed.patch_embedding") + if "network" in k_new: + ls = k_new.split(".") + if ls[2].isdigit(): + k_new = "swiftformer.encoder.network." + ls[1] + ".blocks." + ls[2] + "." + ".".join(ls[3:]) + else: + k_new = k_new.replace("network", "swiftformer.encoder.network") + rename_keys.append((k, k_new)) + return rename_keys + + +@torch.no_grad() +def convert_swiftformer_checkpoint(swiftformer_name, pytorch_dump_folder_path, original_ckpt): + """ + Copy/paste/tweak model's weights to our SwiftFormer structure. 
+ """ + + # define default SwiftFormer configuration + config = SwiftFormerConfig() + + # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size + config.num_labels = 1000 + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + # size of the architecture + if swiftformer_name == "swiftformer_xs": + config.depths = [3, 3, 6, 4] + config.embed_dims = [48, 56, 112, 220] + + elif swiftformer_name == "swiftformer_s": + config.depths = [3, 3, 9, 6] + config.embed_dims = [48, 64, 168, 224] + + elif swiftformer_name == "swiftformer_l1": + config.depths = [4, 3, 10, 5] + config.embed_dims = [48, 96, 192, 384] + + elif swiftformer_name == "swiftformer_l3": + config.depths = [4, 4, 12, 6] + config.embed_dims = [64, 128, 320, 512] + + # load state_dict of original model, remove and rename some keys + if original_ckpt: + if original_ckpt.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url(original_ckpt, map_location="cpu", check_hash=True) + else: + checkpoint = torch.load(original_ckpt, map_location="cpu", weights_only=True) + state_dict = checkpoint + + rename_keys = create_rename_keys(state_dict) + for rename_key_src, rename_key_dest in rename_keys: + rename_key(state_dict, rename_key_src, rename_key_dest) + + # load HuggingFace model + hf_model = SwiftFormerForImageClassification(config).eval() + hf_model.load_state_dict(state_dict) + + # prepare test inputs + image = prepare_img() + processor = ViTImageProcessor.from_pretrained("preprocessor_config") + inputs = processor(images=image, return_tensors="pt") + + # compare outputs from both models + timm_logits = get_expected_output(swiftformer_name) + hf_logits = hf_model(inputs["pixel_values"]).logits + + assert hf_logits.shape == torch.Size([1, 1000]) + assert torch.allclose(hf_logits[0, 0:5], timm_logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {swiftformer_name} to {pytorch_dump_folder_path}") + hf_model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--swiftformer_name", + default="swiftformer_xs", + choices=["swiftformer_xs", "swiftformer_s", "swiftformer_l1", "swiftformer_l3"], + type=str, + help="Name of the SwiftFormer model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="./converted_outputs/", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--original_ckpt", default=None, type=str, help="Path to the original model checkpoint.") + + args = parser.parse_args() + convert_swiftformer_checkpoint(args.swiftformer_name, args.pytorch_dump_folder_path, args.original_ckpt) diff --git a/docs/transformers/src/transformers/models/swiftformer/modeling_swiftformer.py b/docs/transformers/src/transformers/models/swiftformer/modeling_swiftformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5870e91338d0955a3e3611904afcba43ab7e1559 --- /dev/null +++ b/docs/transformers/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -0,0 +1,607 @@ +# coding=utf-8 +# Copyright 2023 MBZUAI and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SwiftFormer model.""" + +import collections.abc +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2CLS +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_swiftformer import SwiftFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SwiftFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "MBZUAI/swiftformer-xs" +_EXPECTED_OUTPUT_SHAPE = [1, 220, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "MBZUAI/swiftformer-xs" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class SwiftFormerPatchEmbedding(nn.Module): + """ + Patch Embedding Layer constructed of two 2D convolutional layers. + + Input: tensor of shape `[batch_size, in_channels, height, width]` + + Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` + """ + + def __init__(self, config: SwiftFormerConfig): + super().__init__() + + in_chs = config.num_channels + out_chs = config.embed_dims[0] + self.patch_embedding = nn.Sequential( + nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_chs // 2, eps=config.batch_norm_eps), + nn.ReLU(), + nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_chs, eps=config.batch_norm_eps), + nn.ReLU(), + ) + + def forward(self, x): + return self.patch_embedding(x) + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
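+
+    Concretely, with `keep_prob = 1 - drop_prob`, each sample in the batch is kept with probability `keep_prob` and
+    the surviving activations are rescaled by `1 / keep_prob`, so the expected value of the output matches the input.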
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class SwiftFormerDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, config: SwiftFormerConfig) -> None: + super().__init__() + self.drop_prob = config.drop_path_rate + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class SwiftFormerEmbeddings(nn.Module): + """ + Embeddings layer consisting of a single 2D convolutional and batch normalization layer. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height/stride, width/stride]` + """ + + def __init__(self, config: SwiftFormerConfig, index: int): + super().__init__() + + patch_size = config.down_patch_size + stride = config.down_stride + padding = config.down_pad + embed_dims = config.embed_dims + + in_chans = embed_dims[index] + embed_dim = embed_dims[index + 1] + + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) + padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) + self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + + +class SwiftFormerConvEncoder(nn.Module): + """ + `SwiftFormerConvEncoder` with 3*3 and 1*1 convolutions. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int): + super().__init__() + hidden_dim = int(config.mlp_ratio * dim) + + self.depth_wise_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) + self.norm = nn.BatchNorm2d(dim, eps=config.batch_norm_eps) + self.point_wise_conv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1) + self.act = nn.GELU() + self.point_wise_conv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1) + self.drop_path = nn.Dropout(p=config.drop_conv_encoder_rate) + self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) + + def forward(self, x): + input = x + x = self.depth_wise_conv(x) + x = self.norm(x) + x = self.point_wise_conv1(x) + x = self.act(x) + x = self.point_wise_conv2(x) + x = input + self.drop_path(self.layer_scale * x) + return x + + +class SwiftFormerMlp(nn.Module): + """ + MLP layer with 1*1 convolutions. 
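+    The hidden dimensionality is `int(in_features * config.mlp_ratio)`.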
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, in_features: int): + super().__init__() + hidden_features = int(in_features * config.mlp_ratio) + self.norm1 = nn.BatchNorm2d(in_features, eps=config.batch_norm_eps) + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + act_layer = ACT2CLS[config.hidden_act] + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, in_features, 1) + self.drop = nn.Dropout(p=config.drop_mlp_rate) + + def forward(self, x): + x = self.norm1(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiftFormerEfficientAdditiveAttention(nn.Module): + """ + Efficient Additive Attention module for SwiftFormer. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int = 512): + super().__init__() + + self.to_query = nn.Linear(dim, dim) + self.to_key = nn.Linear(dim, dim) + + self.w_g = nn.Parameter(torch.randn(dim, 1)) + self.scale_factor = dim**-0.5 + self.proj = nn.Linear(dim, dim) + self.final = nn.Linear(dim, dim) + + def forward(self, x): + query = self.to_query(x) + key = self.to_key(x) + + query = torch.nn.functional.normalize(query, dim=-1) + key = torch.nn.functional.normalize(key, dim=-1) + + query_weight = query @ self.w_g + scaled_query_weight = query_weight * self.scale_factor + scaled_query_weight = scaled_query_weight.softmax(dim=-1) + + global_queries = torch.sum(scaled_query_weight * query, dim=1) + global_queries = global_queries.unsqueeze(1).repeat(1, key.shape[1], 1) + + out = self.proj(global_queries * key) + query + out = self.final(out) + + return out + + +class SwiftFormerLocalRepresentation(nn.Module): + """ + Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int): + super().__init__() + + self.depth_wise_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) + self.norm = nn.BatchNorm2d(dim, eps=config.batch_norm_eps) + self.point_wise_conv1 = nn.Conv2d(dim, dim, kernel_size=1) + self.act = nn.GELU() + self.point_wise_conv2 = nn.Conv2d(dim, dim, kernel_size=1) + self.drop_path = nn.Identity() + self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) + + def forward(self, x): + input = x + x = self.depth_wise_conv(x) + x = self.norm(x) + x = self.point_wise_conv1(x) + x = self.act(x) + x = self.point_wise_conv2(x) + x = input + self.drop_path(self.layer_scale * x) + return x + + +class SwiftFormerEncoderBlock(nn.Module): + """ + SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2) + SwiftFormerEfficientAdditiveAttention, and (3) MLP block. 
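+    Both the attention branch and the MLP branch are added back to the input as residuals, optionally rescaled by
+    learnable per-channel layer scale parameters and regularized with stochastic depth (drop path).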
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels,height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0) -> None: + super().__init__() + + layer_scale_init_value = config.layer_scale_init_value + use_layer_scale = config.use_layer_scale + + self.local_representation = SwiftFormerLocalRepresentation(config, dim=dim) + self.attn = SwiftFormerEfficientAdditiveAttention(config, dim=dim) + self.linear = SwiftFormerMlp(config, in_features=dim) + self.drop_path = SwiftFormerDropPath(config) if drop_path > 0.0 else nn.Identity() + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True + ) + + def forward(self, x): + x = self.local_representation(x) + batch_size, channels, height, width = x.shape + res = self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)) + res = res.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + if self.use_layer_scale: + x = x + self.drop_path(self.layer_scale_1 * res) + x = x + self.drop_path(self.layer_scale_2 * self.linear(x)) + else: + x = x + self.drop_path(res) + x = x + self.drop_path(self.linear(x)) + return x + + +class SwiftFormerStage(nn.Module): + """ + A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final + `SwiftFormerEncoderBlock`. + + Input: tensor in shape `[batch_size, channels, height, width]` + + Output: tensor in shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, index: int) -> None: + super().__init__() + + layer_depths = config.depths + dim = config.embed_dims[index] + depth = layer_depths[index] + + blocks = [] + for block_idx in range(depth): + block_dpr = config.drop_path_rate * (block_idx + sum(layer_depths[:index])) / (sum(layer_depths) - 1) + + if depth - block_idx <= 1: + blocks.append(SwiftFormerEncoderBlock(config, dim=dim, drop_path=block_dpr)) + else: + blocks.append(SwiftFormerConvEncoder(config, dim=dim)) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, input): + for block in self.blocks: + input = block(input) + return input + + +class SwiftFormerEncoder(nn.Module): + def __init__(self, config: SwiftFormerConfig) -> None: + super().__init__() + self.config = config + + embed_dims = config.embed_dims + downsamples = config.downsamples + layer_depths = config.depths + + # Transformer model + network = [] + for i in range(len(layer_depths)): + stage = SwiftFormerStage(config=config, index=i) + network.append(stage) + if i >= len(layer_depths) - 1: + break + if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append(SwiftFormerEmbeddings(config, index=i)) + self.network = nn.ModuleList(network) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = (hidden_states,) 
if output_hidden_states else None
+
+        for block in self.network:
+            hidden_states = block(hidden_states)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+        return BaseModelOutputWithNoAttention(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+        )
+
+
+class SwiftFormerPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = SwiftFormerConfig
+    base_model_prefix = "swiftformer"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["SwiftFormerEncoderBlock"]
+
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Conv2d, nn.Linear)):
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if module.bias is not None:
+                nn.init.constant_(module.bias, 0)
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.constant_(module.bias, 0)
+            nn.init.constant_(module.weight, 1.0)
+
+
+SWIFTFORMER_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`SwiftFormerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIFTFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+            for details.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+""" + + +@add_start_docstrings( + "The bare SwiftFormer Model transformer outputting raw hidden-states without any specific head on top.", + SWIFTFORMER_START_DOCSTRING, +) +class SwiftFormerModel(SwiftFormerPreTrainedModel): + def __init__(self, config: SwiftFormerConfig): + super().__init__(config) + self.config = config + + self.patch_embed = SwiftFormerPatchEmbedding(config) + self.encoder = SwiftFormerEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWIFTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + r""" """ + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.patch_embed(pixel_values) + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return tuple(v for v in encoder_outputs if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + SwiftFormer Model transformer with an image classification head on top (e.g. for ImageNet). + """, + SWIFTFORMER_START_DOCSTRING, +) +class SwiftFormerForImageClassification(SwiftFormerPreTrainedModel): + def __init__(self, config: SwiftFormerConfig) -> None: + super().__init__(config) + + embed_dims = config.embed_dims + + self.num_labels = config.num_labels + self.swiftformer = SwiftFormerModel(config) + + # Classifier head + self.norm = nn.BatchNorm2d(embed_dims[-1], eps=config.batch_norm_eps) + self.head = nn.Linear(embed_dims[-1], self.num_labels) if self.num_labels > 0 else nn.Identity() + self.dist_head = nn.Linear(embed_dims[-1], self.num_labels) if self.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWIFTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # run base model + outputs = self.swiftformer( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs.last_hidden_state if return_dict else outputs[0] + + # run classification head + sequence_output = self.norm(sequence_output) + sequence_output = sequence_output.flatten(2).mean(-1) + cls_out = self.head(sequence_output) + distillation_out = self.dist_head(sequence_output) + logits = (cls_out + distillation_out) / 2 + + # calculate loss + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +__all__ = ["SwiftFormerForImageClassification", "SwiftFormerModel", "SwiftFormerPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/swiftformer/modeling_tf_swiftformer.py b/docs/transformers/src/transformers/models/swiftformer/modeling_tf_swiftformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b7dc0e4f94d0315ae7124830b5fb7660e0056390 --- /dev/null +++ b/docs/transformers/src/transformers/models/swiftformer/modeling_tf_swiftformer.py @@ -0,0 +1,866 @@ +# coding=utf-8 +# Copyright 2024 MBZUAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TensorFlow SwiftFormer model.""" + +import collections.abc +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFImageClassifierOutputWithNoAttention, +) +from ...modeling_tf_utils import TFPreTrainedModel, keras, keras_serializable, unpack_inputs +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_swiftformer import SwiftFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SwiftFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "MBZUAI/swiftformer-xs" +_EXPECTED_OUTPUT_SHAPE = [1, 220, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "MBZUAI/swiftformer-xs" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class TFSwiftFormerPatchEmbeddingSequential(keras.layers.Layer): + """ + The sequential component of the patch embedding layer. + + Input: tensor of shape `[batch_size, in_channels, height, width]` + + Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` + """ + + def __init__(self, config: SwiftFormerConfig, **kwargs): + super().__init__(**kwargs) + self.out_chs = config.embed_dims[0] + + self.zero_padding = keras.layers.ZeroPadding2D(padding=(1, 1)) + self.conv1 = keras.layers.Conv2D(self.out_chs // 2, kernel_size=3, strides=2, name="0") + self.batch_norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="1") + self.conv2 = keras.layers.Conv2D(self.out_chs, kernel_size=3, strides=2, name="3") + self.batch_norm2 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="4") + self.config = config + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + x = self.zero_padding(x) + x = self.conv1(x) + x = self.batch_norm1(x, training=training) + x = get_tf_activation("relu")(x) + x = self.zero_padding(x) + x = self.conv2(x) + x = self.batch_norm2(x, training=training) + x = get_tf_activation("relu")(x) + return x + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "conv1", None) is not None: + with tf.name_scope(self.conv1.name): + self.conv1.build(self.config.num_channels) + if getattr(self, "batch_norm1", None) is not None: + with tf.name_scope(self.batch_norm1.name): + self.batch_norm1.build((None, None, None, self.out_chs // 2)) + if getattr(self, "conv2", None) is not None: + with tf.name_scope(self.conv2.name): + self.conv2.build((None, None, None, self.out_chs // 2)) + if getattr(self, "batch_norm2", None) is not None: + with tf.name_scope(self.batch_norm2.name): + self.batch_norm2.build((None, None, None, self.out_chs)) + self.built = True + + +class TFSwiftFormerPatchEmbedding(keras.layers.Layer): + """ + Patch Embedding Layer constructed of two 2D convolutional layers. 
+ + Input: tensor of shape `[batch_size, in_channels, height, width]` + + Output: tensor of shape `[batch_size, out_channels, height/4, width/4]` + """ + + def __init__(self, config: SwiftFormerConfig, **kwargs): + super().__init__(**kwargs) + self.patch_embedding = TFSwiftFormerPatchEmbeddingSequential(config, name="patch_embedding") + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + return self.patch_embedding(x, training=training) + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "patch_embedding", None) is not None: + with tf.name_scope(self.patch_embedding.name): + self.patch_embedding.build(None) + self.built = True + + +class TFSwiftFormerDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: + super().__init__(**kwargs) + raise NotImplementedError("Drop path is not implemented in TF port") + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + raise NotImplementedError("Drop path is not implemented in TF port") + + +class TFSwiftFormerEmbeddings(keras.layers.Layer): + """ + Embeddings layer consisting of a single 2D convolutional and batch normalization layer. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height/stride, width/stride]` + """ + + def __init__(self, config: SwiftFormerConfig, index: int, **kwargs): + super().__init__(**kwargs) + + patch_size = config.down_patch_size + stride = config.down_stride + padding = config.down_pad + embed_dims = config.embed_dims + + self.in_chans = embed_dims[index] + self.embed_dim = embed_dims[index + 1] + + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) + padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding) + + self.pad = keras.layers.ZeroPadding2D(padding=padding) + self.proj = keras.layers.Conv2D(self.embed_dim, kernel_size=patch_size, strides=stride, name="proj") + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + x = self.pad(x) + x = self.proj(x) + x = self.norm(x, training=training) + return x + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build(self.in_chans) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.embed_dim)) + self.built = True + + +class TFSwiftFormerConvEncoder(keras.layers.Layer): + """ + `SwiftFormerConvEncoder` with 3*3 and 1*1 convolutions. 
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs): + super().__init__(**kwargs) + hidden_dim = int(config.mlp_ratio * dim) + + self.dim = dim + self.pad = keras.layers.ZeroPadding2D(padding=(1, 1)) + self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv") + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + self.point_wise_conv1 = keras.layers.Conv2D(hidden_dim, kernel_size=1, name="point_wise_conv1") + self.act = get_tf_activation("gelu") + self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2") + self.drop_path = keras.layers.Dropout(name="drop_path", rate=config.drop_conv_encoder_rate) + self.hidden_dim = int(config.mlp_ratio * self.dim) + + def build(self, input_shape=None): + if self.built: + return + self.layer_scale = self.add_weight( + name="layer_scale", + shape=self.dim, + initializer="ones", + trainable=True, + ) + + if getattr(self, "depth_wise_conv", None) is not None: + with tf.name_scope(self.depth_wise_conv.name): + self.depth_wise_conv.build(self.dim) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.dim)) + if getattr(self, "point_wise_conv1", None) is not None: + with tf.name_scope(self.point_wise_conv1.name): + self.point_wise_conv1.build(self.dim) + if getattr(self, "point_wise_conv2", None) is not None: + with tf.name_scope(self.point_wise_conv2.name): + self.point_wise_conv2.build(self.hidden_dim) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + self.built = True + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + input = x + x = self.pad(x) + x = self.depth_wise_conv(x) + x = self.norm(x, training=training) + x = self.point_wise_conv1(x) + x = self.act(x) + x = self.point_wise_conv2(x) + x = input + self.drop_path(self.layer_scale * x) + return x + + +class TFSwiftFormerMlp(keras.layers.Layer): + """ + MLP layer with 1*1 convolutions. 
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, in_features: int, **kwargs): + super().__init__(**kwargs) + + hidden_features = int(in_features * config.mlp_ratio) + self.norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm1") + self.fc1 = keras.layers.Conv2D(hidden_features, 1, name="fc1") + act_layer = get_tf_activation(config.hidden_act) + self.act = act_layer + self.fc2 = keras.layers.Conv2D(in_features, 1, name="fc2") + self.drop = keras.layers.Dropout(rate=config.drop_mlp_rate) + self.hidden_features = hidden_features + self.in_features = in_features + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + x = self.norm1(x, training=training) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x, training=training) + x = self.fc2(x) + x = self.drop(x, training=training) + return x + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "norm1", None) is not None: + with tf.name_scope(self.norm1.name): + self.norm1.build((None, None, None, self.in_features)) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build((None, None, None, self.in_features)) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build((None, None, None, self.hidden_features)) + self.built = True + + +class TFSwiftFormerEfficientAdditiveAttention(keras.layers.Layer): + """ + Efficient Additive Attention module for SwiftFormer. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int = 512, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + + self.to_query = keras.layers.Dense(dim, name="to_query") + self.to_key = keras.layers.Dense(dim, name="to_key") + + self.scale_factor = dim**-0.5 + self.proj = keras.layers.Dense(dim, name="proj") + self.final = keras.layers.Dense(dim, name="final") + + def build(self, input_shape=None): + if self.built: + return + self.w_g = self.add_weight( + name="w_g", + shape=(self.dim, 1), + initializer=keras.initializers.RandomNormal(mean=0, stddev=1), + trainable=True, + ) + + if getattr(self, "to_query", None) is not None: + with tf.name_scope(self.to_query.name): + self.to_query.build(self.dim) + if getattr(self, "to_key", None) is not None: + with tf.name_scope(self.to_key.name): + self.to_key.build(self.dim) + if getattr(self, "proj", None) is not None: + with tf.name_scope(self.proj.name): + self.proj.build(self.dim) + if getattr(self, "final", None) is not None: + with tf.name_scope(self.final.name): + self.final.build(self.dim) + self.built = True + + def call(self, x: tf.Tensor) -> tf.Tensor: + query = self.to_query(x) + key = self.to_key(x) + + query = tf.math.l2_normalize(query, dim=-1) + key = tf.math.l2_normalize(key, dim=-1) + + query_weight = query @ self.w_g + scaled_query_weight = query_weight * self.scale_factor + scaled_query_weight = tf.nn.softmax(scaled_query_weight, axis=-1) + + global_queries = tf.math.reduce_sum(scaled_query_weight * query, axis=1) + global_queries = tf.tile(tf.expand_dims(global_queries, 1), (1, key.shape[1], 1)) + + out = self.proj(global_queries * key) + query + out = self.final(out) + + return out + + +class TFSwiftFormerLocalRepresentation(keras.layers.Layer): + """ + Local 
Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions. + + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs): + super().__init__(**kwargs) + + self.dim = dim + + self.pad = keras.layers.ZeroPadding2D(padding=(1, 1)) + self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv") + self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + self.point_wise_conv1 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv1") + self.act = get_tf_activation("gelu") + self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2") + self.drop_path = keras.layers.Identity(name="drop_path") + + def build(self, input_shape=None): + if self.built: + return + self.layer_scale = self.add_weight( + name="layer_scale", + shape=(self.dim), + initializer="ones", + trainable=True, + ) + if getattr(self, "depth_wise_conv", None) is not None: + with tf.name_scope(self.depth_wise_conv.name): + self.depth_wise_conv.build((None, None, None, self.dim)) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.dim)) + if getattr(self, "point_wise_conv1", None) is not None: + with tf.name_scope(self.point_wise_conv1.name): + self.point_wise_conv1.build(self.dim) + if getattr(self, "point_wise_conv2", None) is not None: + with tf.name_scope(self.point_wise_conv2.name): + self.point_wise_conv2.build(self.dim) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + self.built = True + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + input = x + x = self.pad(x) + x = self.depth_wise_conv(x) + x = self.norm(x, training=training) + x = self.point_wise_conv1(x) + x = self.act(x) + x = self.point_wise_conv2(x) + x = input + self.drop_path(self.layer_scale * x, training=training) + return x + + +class TFSwiftFormerEncoderBlock(keras.layers.Layer): + """ + SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2) + SwiftFormerEfficientAdditiveAttention, and (3) MLP block. 
+ + Input: tensor of shape `[batch_size, channels, height, width]` + + Output: tensor of shape `[batch_size, channels,height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): + super().__init__(**kwargs) + + layer_scale_init_value = config.layer_scale_init_value + use_layer_scale = config.use_layer_scale + + self.local_representation = TFSwiftFormerLocalRepresentation(config, dim=dim, name="local_representation") + self.attn = TFSwiftFormerEfficientAdditiveAttention(config, dim=dim, name="attn") + self.linear = TFSwiftFormerMlp(config, in_features=dim, name="linear") + self.drop_path = TFSwiftFormerDropPath(config) if drop_path > 0.0 else keras.layers.Identity() + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.dim = dim + self.layer_scale_init_value = layer_scale_init_value + + def build(self, input_shape=None): + if self.built: + return + self.layer_scale_1 = self.add_weight( + name="layer_scale_1", + shape=self.dim, + initializer=keras.initializers.constant(self.layer_scale_init_value), + trainable=True, + ) + self.layer_scale_2 = self.add_weight( + name="layer_scale_2", + shape=self.dim, + initializer=keras.initializers.constant(self.layer_scale_init_value), + trainable=True, + ) + + if getattr(self, "local_representation", None) is not None: + with tf.name_scope(self.local_representation.name): + self.local_representation.build(None) + if getattr(self, "attn", None) is not None: + with tf.name_scope(self.attn.name): + self.attn.build(None) + if getattr(self, "linear", None) is not None: + with tf.name_scope(self.linear.name): + self.linear.build(None) + self.built = True + + def call(self, x: tf.Tensor, training: bool = False): + x = self.local_representation(x, training=training) + batch_size, height, width, channels = x.shape + + res = tf.reshape(x, [-1, height * width, channels]) + res = self.attn(res) + res = tf.reshape(res, [-1, height, width, channels]) + if self.use_layer_scale: + x = x + self.drop_path(self.layer_scale_1 * res, training=training) + x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training) + else: + x = x + self.drop_path(res, training=training) + x = x + self.drop_path(self.linear(x), training=training) + return x + + +class TFSwiftFormerStage(keras.layers.Layer): + """ + A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final + `SwiftFormerEncoderBlock`. 
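+    Only the final block of each stage is a `TFSwiftFormerEncoderBlock`; the `depth - 1` blocks before it are
+    `TFSwiftFormerConvEncoder` blocks.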
+ + Input: tensor in shape `[batch_size, channels, height, width]` + + Output: tensor in shape `[batch_size, channels, height, width]` + """ + + def __init__(self, config: SwiftFormerConfig, index: int, **kwargs) -> None: + super().__init__(**kwargs) + + layer_depths = config.depths + dim = config.embed_dims[index] + depth = layer_depths[index] + + self.blocks = [] + for block_idx in range(depth): + block_dpr = config.drop_path_rate * (block_idx + sum(layer_depths[:index])) / (sum(layer_depths) - 1) + + if depth - block_idx <= 1: + self.blocks.append( + TFSwiftFormerEncoderBlock(config, dim=dim, drop_path=block_dpr, name=f"blocks_._{block_idx}") + ) + else: + self.blocks.append(TFSwiftFormerConvEncoder(config, dim=dim, name=f"blocks_._{block_idx}")) + + def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor: + for i, block in enumerate(self.blocks): + input = block(input, training=training) + return input + + def build(self, input_shape=None): + for layer in self.blocks: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwiftFormerEncoder(keras.layers.Layer): + def __init__(self, config: SwiftFormerConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + embed_dims = config.embed_dims + downsamples = config.downsamples + layer_depths = config.depths + + # Transformer model + self.network = [] + name_i = 0 + for i in range(len(layer_depths)): + stage = TFSwiftFormerStage(config, index=i, name=f"network_._{name_i}") + self.network.append(stage) + name_i += 1 + if i >= len(layer_depths) - 1: + break + if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + self.network.append(TFSwiftFormerEmbeddings(config, index=i, name=f"network_._{name_i}")) + name_i += 1 + + self.gradient_checkpointing = False + + def call( + self, + hidden_states: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple, TFBaseModelOutputWithNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for i, block in enumerate(self.network): + hidden_states = block(hidden_states, training=training) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) + if all_hidden_states: + all_hidden_states = tuple(tf.transpose(s, perm=[0, 3, 1, 2]) for s in all_hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + def build(self, input_shape=None): + for layer in self.network: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwiftFormerPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwiftFormerConfig + base_model_prefix = "swiftformer" + main_input_name = "pixel_values" + + +TFSWIFTFORMER_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. 
+    Check the superclass documentation for the generic methods the library implements for all its models (such as
+    downloading or saving, resizing the input embeddings, pruning heads etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    TF 2.0 models accept two formats as inputs:
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional arguments.
+    This second option is useful when using the [`keras.Model.fit`] method, which currently requires having all the
+    tensors in the first argument of the model call function: `model(inputs)`.
+    If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
+    first positional argument:
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated with the input names given in the docstring:
+      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Parameters:
+        config ([`SwiftFormerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TFSWIFTFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+            for details.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to run the model in training mode.
+"""
+
+
+@keras_serializable
+class TFSwiftFormerMainLayer(keras.layers.Layer):
+    config_class = SwiftFormerConfig
+
+    def __init__(self, config: SwiftFormerConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+
+        self.patch_embed = TFSwiftFormerPatchEmbedding(config, name="patch_embed")
+        self.encoder = TFSwiftFormerEncoder(config, name="encoder")
+
+    @unpack_inputs
+    def call(
+        self,
+        pixel_values: Optional[tf.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[Tuple, TFBaseModelOutputWithNoAttention]:
+        r""" """
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # TF 2.0 image layers can't use NCHW format when running on CPU.
+        # We transpose to NHWC format and then transpose back after the full forward pass.
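+        # (`TFSwiftFormerEncoder.call` transposes the hidden states back to NCHW before returning.)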
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+        pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1])
+
+        embedding_output = self.patch_embed(pixel_values, training=training)
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        if not return_dict:
+            return tuple(v for v in encoder_outputs if v is not None)
+
+        return TFBaseModelOutputWithNoAttention(
+            last_hidden_state=encoder_outputs.last_hidden_state,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        if getattr(self, "patch_embed", None) is not None:
+            with tf.name_scope(self.patch_embed.name):
+                self.patch_embed.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        self.built = True
+
+
+@add_start_docstrings(
+    "The bare TFSwiftFormer Model transformer outputting raw hidden-states without any specific head on top.",
+    TFSWIFTFORMER_START_DOCSTRING,
+)
+class TFSwiftFormerModel(TFSwiftFormerPreTrainedModel):
+    def __init__(self, config: SwiftFormerConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
+    def call(
+        self,
+        pixel_values: Optional[tf.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutputWithNoAttention, Tuple[tf.Tensor]]:
+        outputs = self.swiftformer(
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        if getattr(self, "swiftformer", None) is not None:
+            with tf.name_scope(self.swiftformer.name):
+                self.swiftformer.build(None)
+        self.built = True
+
+
+@add_start_docstrings(
+    """
+    TFSwiftFormer Model transformer with an image classification head on top (e.g. for ImageNet).
+    """,
+    TFSWIFTFORMER_START_DOCSTRING,
+)
+class TFSwiftFormerForImageClassification(TFSwiftFormerPreTrainedModel):
+    def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
+        super().__init__(config, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer")
+
+        # Classifier head
+        self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
+        self.head = (
+            keras.layers.Dense(self.num_labels, name="head")
+            if self.num_labels > 0
+            else keras.layers.Identity(name="head")
+        )
+        self.dist_head = (
+            keras.layers.Dense(self.num_labels, name="dist_head")
+            if self.num_labels > 0
+            else keras.layers.Identity(name="dist_head")
+        )
+
+    def hf_compute_loss(self, labels, logits):
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = "regression"
+            elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
+                self.config.problem_type = "single_label_classification"
+            else:
+                self.config.problem_type = "multi_label_classification"
+
+        if self.config.problem_type == "regression":
+            loss_fct = keras.losses.MSE
+            if self.num_labels == 1:
+                # tf.Tensor has no `.squeeze()` method, so use `tf.squeeze` instead
+                loss = loss_fct(tf.squeeze(labels), tf.squeeze(logits))
+            else:
+                loss = loss_fct(labels, logits)
+        elif self.config.problem_type == "single_label_classification":
+            loss_fct = keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True, reduction=keras.losses.Reduction.NONE
+            )
+            loss = loss_fct(labels, logits)
+        elif self.config.problem_type == "multi_label_classification":
+            # multi-label targets are multi-hot vectors, so a sigmoid binary
+            # cross-entropy is used here rather than a (sparse) categorical one
+            loss_fct = keras.losses.BinaryCrossentropy(
+                from_logits=True,
+                reduction=keras.losses.Reduction.NONE,
+            )
+            loss = loss_fct(labels, logits)
+        else:
+            loss = None
+
+        return loss
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
+    def call(
+        self,
+        pixel_values: Optional[tf.Tensor] = None,
+        labels: Optional[tf.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
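+
+            Note that the returned logits are the average of the classification head and the distillation head
+            outputs (see the body of `call` below).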
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # run base model + outputs = self.swiftformer( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs.last_hidden_state if return_dict else outputs[0] + sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1]) + + # run classification head + sequence_output = self.norm(sequence_output, training=training) + sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2]) + _, num_channels, height, width = sequence_output.shape + sequence_output = tf.reshape(sequence_output, [-1, num_channels, height * width]) + sequence_output = tf.reduce_mean(sequence_output, axis=-1) + cls_out = self.head(sequence_output) + distillation_out = self.dist_head(sequence_output) + logits = (cls_out + distillation_out) / 2 + + # calculate loss + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + if getattr(self, "swiftformer", None) is not None: + with tf.name_scope(self.swiftformer.name): + self.swiftformer.build(None) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build((None, None, None, self.config.embed_dims[-1])) + if getattr(self, "head", None) is not None: + with tf.name_scope(self.head.name): + self.head.build(self.config.embed_dims[-1]) + if getattr(self, "dist_head", None) is not None: + with tf.name_scope(self.dist_head.name): + self.dist_head.build(self.config.embed_dims[-1]) + self.built = True + + +__all__ = ["TFSwiftFormerForImageClassification", "TFSwiftFormerModel", "TFSwiftFormerPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/swin/__init__.py b/docs/transformers/src/transformers/models/swin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc5871b0375e8a5553e30dce1e5dde807a1b493 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
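+
+# The lazy re-export pattern below keeps `import transformers` cheap: `_LazyModule`
+# replaces this package in `sys.modules`, so `configuration_swin`, `modeling_swin` and
+# `modeling_tf_swin` are only imported when one of their attributes is first accessed.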
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_swin import * + from .modeling_swin import * + from .modeling_tf_swin import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/swin/configuration_swin.py b/docs/transformers/src/transformers/models/swin/configuration_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..da6ba9871407a644b99435e69c6ec46e26fd8735 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin/configuration_swin.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swin Transformer model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class SwinConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SwinModel`]. It is used to instantiate a Swin + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Swin + [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. 
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to add absolute position embeddings to the patch embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + encoder_stride (`int`, *optional*, defaults to 32): + Factor to increase the spatial resolution by in the decoder head for masked image modeling. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. 
+ + Example: + + ```python + >>> from transformers import SwinConfig, SwinModel + + >>> # Initializing a Swin microsoft/swin-tiny-patch4-window7-224 style configuration + >>> configuration = SwinConfig() + + >>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration + >>> model = SwinModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_stride=32, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.encoder_stride = encoder_stride + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +class SwinOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["SwinConfig", "SwinOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/swin/convert_swin_simmim_to_pytorch.py b/docs/transformers/src/transformers/models/swin/convert_swin_simmim_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a87ff693af209656ccf0e53a0ea4c2e26e5ba51 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin/convert_swin_simmim_to_pytorch.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Swin SimMIM checkpoints from the original repository. + +URL: https://github.com/microsoft/Swin-Transformer/blob/main/MODELHUB.md#simmim-pretrained-swin-v1-models""" + +import argparse + +import requests +import torch +from PIL import Image + +from transformers import SwinConfig, SwinForMaskedImageModeling, ViTImageProcessor + + +def get_swin_config(model_name): + config = SwinConfig(image_size=192) + + if "base" in model_name: + window_size = 6 + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + elif "large" in model_name: + window_size = 12 + embed_dim = 192 + depths = (2, 2, 18, 2) + num_heads = (6, 12, 24, 48) + else: + raise ValueError("Model not supported, only supports base and large variants") + + config.window_size = window_size + config.embed_dim = embed_dim + config.depths = depths + config.num_heads = num_heads + + return config + + +def rename_key(name): + if "encoder.mask_token" in name: + name = name.replace("encoder.mask_token", "embeddings.mask_token") + if "encoder.patch_embed.proj" in name: + name = name.replace("encoder.patch_embed.proj", "embeddings.patch_embeddings.projection") + if "encoder.patch_embed.norm" in name: + name = name.replace("encoder.patch_embed.norm", "embeddings.norm") + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + + if name == "encoder.norm.weight": + name = "layernorm.weight" + if name == "encoder.norm.bias": + name = "layernorm.bias" + + if "decoder" in name: + pass + else: + name = "swin." 
+ name
+
+    return name
+
+
+def convert_state_dict(orig_state_dict, model):
+    for key in orig_state_dict.copy().keys():
+        val = orig_state_dict.pop(key)
+
+        if "attn_mask" in key:
+            # attention masks are recomputed on the fly, so drop them from the state dict
+            pass
+        elif "qkv" in key:
+            key_split = key.split(".")
+            layer_num = int(key_split[2])
+            block_num = int(key_split[4])
+            dim = model.swin.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size
+
+            if "weight" in key:
+                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"] = (
+                    val[:dim, :]
+                )
+                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = val[
+                    dim : dim * 2, :
+                ]
+                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"] = (
+                    val[-dim:, :]
+                )
+            else:
+                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = val[
+                    :dim
+                ]
+                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = val[
+                    dim : dim * 2
+                ]
+                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = val[
+                    -dim:
+                ]
+        else:
+            orig_state_dict[rename_key(key)] = val
+
+    return orig_state_dict
+
+
+def convert_swin_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
+    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"]
+
+    config = get_swin_config(model_name)
+    model = SwinForMaskedImageModeling(config)
+    model.eval()
+
+    new_state_dict = convert_state_dict(state_dict, model)
+    model.load_state_dict(new_state_dict)
+
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+    image_processor = ViTImageProcessor(size={"height": 192, "width": 192})
+    image = Image.open(requests.get(url, stream=True).raw)
+    inputs = image_processor(images=image, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = model(**inputs).reconstruction
+
+    # the reconstruction is a tensor, so print its shape as a sanity check
+    print(f"Shape of reconstruction: {outputs.shape}")
+    print("Looks ok!")
+
+    if pytorch_dump_folder_path is not None:
+        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
+        model.save_pretrained(pytorch_dump_folder_path)
+
+        print(f"Saving image processor to {pytorch_dump_folder_path}")
+        image_processor.save_pretrained(pytorch_dump_folder_path)
+
+    if push_to_hub:
+        print(f"Pushing model and image processor for {model_name} to hub")
+        model.push_to_hub(f"microsoft/{model_name}")
+        image_processor.push_to_hub(f"microsoft/{model_name}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--model_name",
+        default="swin-base-simmim-window6-192",
+        type=str,
+        choices=["swin-base-simmim-window6-192", "swin-large-simmim-window12-192"],
+        help="Name of the Swin SimMIM model you'd like to convert.",
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        default="/Users/nielsrogge/Documents/SwinSimMIM/simmim_pretrain__swin_base__img192_window6__100ep.pth",
+        type=str,
+        help="Path to the original PyTorch checkpoint (.pth file).",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+    )
+    parser.add_argument(
+        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+    )
+
+    args = parser.parse_args()
+    convert_swin_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/src/transformers/models/swin/convert_swin_timm_to_pytorch.py b/docs/transformers/src/transformers/models/swin/convert_swin_timm_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91249b272baeba47f5798bd559c39789b6b0224
--- /dev/null
+++ b/docs/transformers/src/transformers/models/swin/convert_swin_timm_to_pytorch.py
@@ -0,0 +1,173 @@
+import argparse
+import json
+
+import requests
+import timm
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import AutoImageProcessor, SwinConfig, SwinForImageClassification
+
+
+def get_swin_config(swin_name):
+    config = SwinConfig()
+    name_split = swin_name.split("_")
+
+    model_size = name_split[1]
+    img_size = int(name_split[4])
+    # strip the literal "window" prefix so that two-digit sizes such as "window12"
+    # are parsed correctly (taking only the last character would yield 2)
+    window_size = int(name_split[3][len("window") :])
+
+    if model_size == "tiny":
+        embed_dim = 96
+        depths = (2, 2, 6, 2)
+        num_heads = (3, 6, 12, 24)
+    elif model_size == "small":
+        embed_dim = 96
+        depths = (2, 2, 18, 2)
+        num_heads = (3, 6, 12, 24)
+    elif model_size == "base":
+        embed_dim = 128
+        depths = (2, 2, 18, 2)
+        num_heads = (4, 8, 16, 32)
+    else:
+        embed_dim = 192
+        depths = (2, 2, 18, 2)
+        num_heads = (6, 12, 24, 48)
+
+    if "in22k" in swin_name:
+        num_classes = 21841
+    else:
+        num_classes = 1000
+        repo_id = "huggingface/label-files"
+        filename = "imagenet-1k-id2label.json"
+        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+        id2label = {int(k): v for k, v in id2label.items()}
+        config.id2label = id2label
+        config.label2id = {v: k for k, v in id2label.items()}
+
+    config.image_size = img_size
+    config.num_labels = num_classes
+    config.embed_dim = embed_dim
+    config.depths = depths
+    config.num_heads = num_heads
+    config.window_size = window_size
+
+    return config
+
+
+def rename_key(name):
+    if "patch_embed.proj" in name:
+        name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
+    if "patch_embed.norm" in name:
+        name = name.replace("patch_embed.norm", "embeddings.norm")
+    if "layers" in name:
+        name = "encoder." + name
+    if "attn.proj" in name:
+        name = name.replace("attn.proj", "attention.output.dense")
+    if "attn" in name:
+        name = name.replace("attn", "attention.self")
+    if "norm1" in name:
+        name = name.replace("norm1", "layernorm_before")
+    if "norm2" in name:
+        name = name.replace("norm2", "layernorm_after")
+    if "mlp.fc1" in name:
+        name = name.replace("mlp.fc1", "intermediate.dense")
+    if "mlp.fc2" in name:
+        name = name.replace("mlp.fc2", "output.dense")
+
+    if name == "norm.weight":
+        name = "layernorm.weight"
+    if name == "norm.bias":
+        name = "layernorm.bias"
+
+    if "head" in name:
+        name = name.replace("head", "classifier")
+    else:
+        name = "swin." 
+ name + + return name + + +def convert_state_dict(orig_state_dict, model): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "mask" in key: + continue + elif "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[1]) + block_num = int(key_split[3]) + dim = model.swin.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size + + if "weight" in key: + orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"] = ( + val[:dim, :] + ) + orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = val[ + dim : dim * 2, : + ] + orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"] = ( + val[-dim:, :] + ) + else: + orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = val[ + :dim + ] + orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = val[ + dim : dim * 2 + ] + orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = val[ + -dim: + ] + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +def convert_swin_checkpoint(swin_name, pytorch_dump_folder_path): + timm_model = timm.create_model(swin_name, pretrained=True) + timm_model.eval() + + config = get_swin_config(swin_name) + model = SwinForImageClassification(config) + model.eval() + + new_state_dict = convert_state_dict(timm_model.state_dict(), model) + model.load_state_dict(new_state_dict) + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + image_processor = AutoImageProcessor.from_pretrained("microsoft/{}".format(swin_name.replace("_", "-"))) + image = Image.open(requests.get(url, stream=True).raw) + inputs = image_processor(images=image, return_tensors="pt") + + timm_outs = timm_model(inputs["pixel_values"]) + hf_outs = model(**inputs).logits + + assert torch.allclose(timm_outs, hf_outs, atol=1e-3) + + print(f"Saving model {swin_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--swin_name", + default="swin_tiny_patch4_window7_224", + type=str, + help="Name of the Swin timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_swin_checkpoint(args.swin_name, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/swin/modeling_swin.py b/docs/transformers/src/transformers/models/swin/modeling_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..ad998f6ff666eaa6a1a0e5eb75e8d735d48e6292 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin/modeling_swin.py @@ -0,0 +1,1412 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Swin Transformer model.""" + +import collections.abc +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_swin import SwinConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SwinConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. + + +@dataclass +class SwinEncoderOutput(ModelOutput): + """ + Swin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. 
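+
+            For example, with the default `microsoft/swin-tiny-patch4-window7-224` configuration these run from
+            `(batch_size, 96, 56, 56)` after the embeddings down to `(batch_size, 768, 7, 7)` after the last stage.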
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SwinModelOutput(ModelOutput):
+    """
+    Swin model's outputs that also contain a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+            Average pooling of the last layer hidden-state.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SwinMaskedImageModelingOutput(ModelOutput):
+    """
+    Swin masked image model outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+            Masked image modeling (MIM) loss.
+        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Reconstructed pixel values.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + + +@dataclass +class SwinImageClassifierOutput(ModelOutput): + """ + Swin outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. 
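+
+    A quick shape check (illustrative; any input whose height and width are divisible by `window_size` behaves the
+    same way):
+
+    ```python
+    >>> import torch
+
+    >>> feature = torch.randn(1, 8, 8, 96)
+    >>> window_partition(feature, window_size=4).shape
+    torch.Size([4, 4, 4, 96])
+    ```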
+ """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +class SwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = SwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +class SwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
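+
+    For example, with the default configuration (`image_size=224`, `patch_size=4`, `embed_dim=96`), a
+    `(batch_size, 3, 224, 224)` image is projected to a `(batch_size, 3136, 96)` sequence, since
+    (224 / 4) * (224 / 4) = 3136 patches.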
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +class SwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
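+
+    Each merging step concatenates every 2x2 neighborhood of patches channel-wise (giving `4 * dim` features),
+    normalizes the result, and projects it back to `2 * dim`, so the spatial resolution is halved while the channel
+    dimension doubles.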
+    """
+
+    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def maybe_pad(self, input_feature, height, width):
+        should_pad = (height % 2 == 1) or (width % 2 == 1)
+        if should_pad:
+            pad_values = (0, 0, 0, width % 2, 0, height % 2)
+            input_feature = nn.functional.pad(input_feature, pad_values)
+
+        return input_feature
+
+    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
+        height, width = input_dimensions
+        # `dim` is height * width
+        batch_size, dim, num_channels = input_feature.shape
+
+        input_feature = input_feature.view(batch_size, height, width, num_channels)
+        # pad input so that height and width are divisible by 2, if needed
+        input_feature = self.maybe_pad(input_feature, height, width)
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_0 = input_feature[:, 0::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_1 = input_feature[:, 1::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_2 = input_feature[:, 0::2, 1::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_3 = input_feature[:, 1::2, 1::2, :]
+        # [batch_size, height/2, width/2, 4*num_channels]
+        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # [batch_size, height/2 * width/2, 4*num_channels]
+
+        input_feature = self.norm(input_feature)
+        input_feature = self.reduction(input_feature)
+
+        return input_feature
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
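+
+    Concretely, each sample in the batch is zeroed out with probability `drop_prob`, and the surviving samples are
+    rescaled by `1 / (1 - drop_prob)` so that the expected value of the output matches the input.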
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin +class SwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class SwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
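+        # The scores are computed per window: hidden_states arrives here flattened to
+        # (num_windows * batch_size, window_area, num_channels), and a learned relative
+        # position bias (looked up through the `relative_position_index` buffer built in
+        # `__init__`) is added to every window's score matrix before the softmax.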
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+        relative_position_bias = relative_position_bias.view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in SwinModel forward() function)
+            mask_shape = attention_mask.shape[0]
+            attention_scores = attention_scores.view(
+                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+            )
+            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
+            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+class SwinSelfOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, dim)
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+class SwinAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, window_size):
+        super().__init__()
+        self.self = SwinSelfAttention(config, dim, num_heads, window_size)
+        self.output = SwinSelfOutput(config, dim)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, 
output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size) + self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = SwinIntermediate(config, dim) + self.output = SwinOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, 
+ input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class SwinStage(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + SwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, 
layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.layers = nn.ModuleList( + [ + SwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=SwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, SwinEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) + else: + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange b (h w) c -> b c h w + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + 
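+                # channels-first view (batch_size, hidden_size, height, width) at the pre-downsampling
+                # resolution, so consumers such as SwinBackbone receive per-stage feature maps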
all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return SwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +class SwinPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["SwinStage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, SwinEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, SwinSelfAttention): + module.relative_position_bias_table.data.zero_() + + +SWIN_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SwinConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWIN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", + SWIN_START_DOCSTRING, + """ + add_pooling_layer (`bool`, *optional*, defaults to `True`): + Whether or not to apply pooling layer. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether or not to create and apply mask tokens in the embedding layer. + """, +) +class SwinModel(SwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = SwinEncoder(config, self.embeddings.patch_grid) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SwinModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SwinModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
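+
+        Example (a minimal sketch; it masks roughly half of the patch tokens and assumes that passing
+        `use_mask_token=True` through `from_pretrained` reaches the model constructor):
+
+        ```python
+        >>> import torch
+        >>> from transformers import SwinModel
+
+        >>> model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224", use_mask_token=True)
+        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+        >>> bool_masked_pos = torch.rand(1, num_patches) < 0.5
+        >>> outputs = model(torch.randn(1, 3, 224, 224), bool_masked_pos=bool_masked_pos)
+        >>> list(outputs.last_hidden_state.shape)
+        [1, 49, 768]
+        ```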
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return SwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886). + + + + Note that we provide a script to pre-train this model on custom data in our [examples + directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining). + + + """, + SWIN_START_DOCSTRING, +) +class SwinForMaskedImageModeling(SwinPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True) + + num_features = int(config.embed_dim * 2 ** (config.num_layers - 1)) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1 + ), + nn.PixelShuffle(config.encoder_stride), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SwinMaskedImageModelingOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192") + >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 192, 192] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output.transpose(1, 2) + batch_size, num_channels, sequence_length = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[2:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return SwinMaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. 
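+
+    A minimal sketch of that use (random pixel values stand in for a real, resized image; the checkpoint
+    is the documented base checkpoint, not a requirement):
+
+    ```python
+    >>> import torch
+    >>> from transformers import SwinForImageClassification
+
+    >>> model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+    >>> pixel_values = torch.randn(1, 3, 384, 384)  # larger than the 224x224 pre-training resolution
+    >>> logits = model(pixel_values, interpolate_pos_encoding=True).logits
+    ```
+
+    </Tip>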
+ + + """, + SWIN_START_DOCSTRING, +) +class SwinForImageClassification(SwinPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.swin = SwinModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.swin.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=SwinImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SwinImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SwinImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Swin backbone, to be used with frameworks like DETR and MaskFormer. 
+ """, + SWIN_START_DOCSTRING, +) +class SwinBackbone(SwinPreTrainedModel, BackboneMixin): + def __init__(self, config: SwinConfig): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, self.embeddings.patch_grid) + + # Add layer norms to hidden states of out_features + hidden_states_norms = {} + for stage, num_channels in zip(self._out_features, self.channels): + hidden_states_norms[stage] = nn.LayerNorm(num_channels) + self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output, input_dimensions = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=None, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + always_partition=True, + return_dict=True, + ) + + hidden_states = outputs.reshaped_hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + batch_size, num_channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state = hidden_state.view(batch_size, height * width, num_channels) + hidden_state = self.hidden_states_norms[stage](hidden_state) + hidden_state = hidden_state.view(batch_size, height, width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SwinForImageClassification", + "SwinForMaskedImageModeling", + "SwinModel", + "SwinPreTrainedModel", + "SwinBackbone", +] diff --git 
a/docs/transformers/src/transformers/models/swin/modeling_tf_swin.py b/docs/transformers/src/transformers/models/swin/modeling_tf_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c4fb1620fd2a0feb4a352b1de12cb4e7fc6054 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin/modeling_tf_swin.py @@ -0,0 +1,1638 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 Swin Transformer model.""" + +from __future__ import annotations + +import collections.abc +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_swin import SwinConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "SwinConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow +# implementations of PyTorch functionalities in the timm library. + + +@dataclass +class TFSwinEncoderOutput(ModelOutput): + """ + Swin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFSwinModelOutput(ModelOutput): + """ + Swin model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[tf.Tensor] = None + pooler_output: tf.Tensor | None = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFSwinMaskedImageModelingOutput(ModelOutput): + """ + Swin masked image model outputs. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Masked image modeling (MLM) loss. + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: tf.Tensor | None = None + reconstruction: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + + +@dataclass +class TFSwinImageClassifierOutput(ModelOutput): + """ + Swin outputs for image classification. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape + `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: tf.Tensor | None = None + logits: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None + + +def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor: + """ + Partitions the given input into windows. 
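+
+    Example (a shape-only sketch):
+
+    ```python
+    >>> import tensorflow as tf
+
+    >>> feature = tf.zeros((1, 8, 8, 96))  # (batch, height, width, channels)
+    >>> window_partition(feature, window_size=4).shape  # four non-overlapping 4x4 windows
+    TensorShape([4, 4, 4, 96])
+    ```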
+ """ + batch_size, height, width, num_channels = shape_list(input_feature) + input_feature = tf.reshape( + input_feature, + (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels), + ) + windows = tf.transpose(input_feature, (0, 1, 3, 2, 4, 5)) + windows = tf.reshape(windows, (-1, window_size, window_size, num_channels)) + return windows + + +def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor: + """ + Merges windows to produce higher resolution features. + """ + x = tf.shape(windows)[0] + y = tf.cast(height * width / (window_size * window_size), tf.int32) + batch_size = tf.math.floordiv(x, y) + windows = tf.reshape( + windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1) + ) + windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5)) + windows = tf.reshape(windows, (batch_size, height, width, -1)) + return windows + + +def drop_path( + input: tf.Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +) -> tf.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + input_shape = shape_list(input) + ndim = len(input_shape) + shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = tf.random.uniform(shape) + random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0) + if keep_prob > 0.0 and scale_by_keep: + random_tensor /= keep_prob + return input * random_tensor + + +class TFSwinEmbeddings(keras.layers.Layer): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None: + super().__init__(**kwargs) + self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings") + self.num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.embed_dim = config.embed_dim + self.use_mask_token = use_mask_token + self.use_absolute_embeddings = config.use_absolute_embeddings + + self.norm = keras.layers.LayerNormalization(name="norm", epsilon=1e-5) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout") + self.config = config + + def build(self, input_shape: tf.TensorShape) -> None: + if self.use_mask_token: + self.mask_token = self.add_weight(shape=(1, 1, self.embed_dim), initializer="zeros", name="mask_token") + else: + self.mask_token = None + + if self.use_absolute_embeddings: + self.position_embeddings = self.add_weight( + (1, self.num_patches + 1, self.embed_dim), initializer="zeros", name="positional_embeddings" + ) + else: + self.position_embeddings = None + + if self.built: + return + self.built = True + if getattr(self, "patch_embeddings", None) is not None: + with tf.name_scope(self.patch_embeddings.name): + self.patch_embeddings.build(None) + if getattr(self, "norm", None) is not None: + with tf.name_scope(self.norm.name): + self.norm.build([None, None, self.config.embed_dim]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + def call( + self, pixel_values: tf.Tensor, bool_masked_pos: Optional[bool] = None, training: bool = False + ) -> Tuple[tf.Tensor, Tuple[int, int]]: + embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training) + embeddings = 
self.norm(embeddings, training=training) + batch_size, seq_len, _ = shape_list(embeddings) + + if bool_masked_pos is not None: + mask_tokens = tf.repeat(self.mask_token, batch_size, 0) + mask_tokens = tf.repeat(mask_tokens, seq_len, 1) + # replace the masked visual tokens by mask_tokens + mask = tf.expand_dims(bool_masked_pos, -1) + mask = tf.cast(mask, mask_tokens.dtype) + + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings, training=training) + + return embeddings, output_dimensions + + +class TFSwinPatchEmbeddings(keras.layers.Layer): + """ + Image to Patch Embedding. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = keras.layers.Conv2D( + filters=hidden_size, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + name="projection", + ) + + def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor: + if width % self.patch_size[1] != 0: + pad_values = ((0, 0), (0, 0), (0, 0), (0, self.patch_size[1] - width % self.patch_size[1])) + pixel_values = tf.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = ((0, 0), (0, 0), (0, self.patch_size[0] - height % self.patch_size[0]), (0, 0)) + pixel_values = tf.pad(pixel_values, pad_values) + return pixel_values + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]: + _, num_channels, height, width = shape_list(pixel_values) + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + + # B,C,H,W -> B,H,W,C + pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1)) + + embeddings = self.projection(pixel_values, training=training) + + # B,H,W,C -> B,C,H,W + embeddings = tf.transpose(embeddings, (0, 3, 1, 2)) + + batch_size, channels, height, width = shape_list(embeddings) + output_dimensions = (height, width) + + embeddings = tf.reshape(embeddings, (batch_size, channels, -1)) + embeddings = tf.transpose(embeddings, (0, 2, 1)) + return embeddings, output_dimensions + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +class TFSwinPatchMerging(keras.layers.Layer): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. 
+        norm_layer (`keras.layers.Layer`, *optional*, defaults to `keras.layers.LayerNormalization`):
+            Normalization layer class.
+    """
+
+    def __init__(
+        self, input_resolution: Tuple[int, int], dim: int, norm_layer: Optional[Callable] = None, **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = keras.layers.Dense(2 * dim, use_bias=False, name="reduction")
+        if norm_layer is None:
+            # Use same default epsilon as PyTorch
+            self.norm = keras.layers.LayerNormalization(epsilon=1e-5, name="norm")
+        else:
+            self.norm = norm_layer(name="norm")
+
+    def maybe_pad(self, input_feature: tf.Tensor, height: int, width: int) -> tf.Tensor:
+        should_pad = (height % 2 == 1) or (width % 2 == 1)
+        if should_pad:
+            pad_values = ((0, 0), (0, height % 2), (0, width % 2), (0, 0))
+            input_feature = tf.pad(input_feature, pad_values)
+
+        return input_feature
+
+    def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
+        height, width = input_dimensions
+        # the sequence length of `input_feature` is height * width
+        batch_size, _, num_channels = shape_list(input_feature)
+
+        input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels))
+        # pad input so that height and width are divisible by 2, if needed
+        input_feature = self.maybe_pad(input_feature, height, width)
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_0 = input_feature[:, 0::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_1 = input_feature[:, 1::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_2 = input_feature[:, 0::2, 1::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_3 = input_feature[:, 1::2, 1::2, :]
+        # batch_size height/2 width/2 4*num_channels
+        input_feature = tf.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+        input_feature = tf.reshape(
+            input_feature, (batch_size, -1, 4 * num_channels)
+        )  # batch_size height/2*width/2 4*C
+
+        input_feature = self.norm(input_feature, training=training)
+        input_feature = self.reduction(input_feature, training=training)
+
+        return input_feature
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "reduction", None) is not None:
+            with tf.name_scope(self.reduction.name):
+                self.reduction.build([None, None, 4 * self.dim])
+        if getattr(self, "norm", None) is not None:
+            with tf.name_scope(self.norm.name):
+                self.norm.build([None, None, 4 * self.dim])
+
+
+class TFSwinDropPath(keras.layers.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None, scale_by_keep: bool = True, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor:
+        return drop_path(input, self.drop_prob, training, self.scale_by_keep)
+
+
+class TFSwinSelfAttention(keras.layers.Layer):
+    def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if dim % num_heads != 0:
+            raise ValueError(
+                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+            )
+
+        self.num_attention_heads = num_heads
+        self.attention_head_size = int(dim / num_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
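+        # `all_head_size` equals `dim` whenever `dim` is divisible by `num_heads` (e.g. dim=96 with
+        # num_heads=3 gives 3 heads of 32 channels each), so the query/key/value projections below
+        # preserve the channel dimension of the window tokens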
window_size = config.window_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.query = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=config.qkv_bias, + name="query", + ) + self.key = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=config.qkv_bias, + name="key", + ) + self.value = keras.layers.Dense( + self.all_head_size, + kernel_initializer=get_initializer(config.initializer_range), + use_bias=config.qkv_bias, + name="value", + ) + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + + def build(self, input_shape: tf.TensorShape) -> None: + self.relative_position_bias_table = self.add_weight( + shape=(((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)), self.num_attention_heads), + initializer="zeros", + name="relative_position_bias_table", + ) + self.relative_position_index = self.add_weight( + shape=(self.window_size[0] ** 2, self.window_size[1] ** 2), + trainable=False, + dtype=tf.int32, + name="relative_position_index", + ) + + # get pair-wise relative position index for each token inside the window + coords_h = tf.range(self.window_size[0]) + coords_w = tf.range(self.window_size[1]) + coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij")) + coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1)) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = tf.transpose(relative_coords, (1, 2, 0)) + + stack_0, stack_1 = tf.unstack(relative_coords, axis=2) + stack_0 += self.window_size[0] - 1 + stack_0 *= 2 * self.window_size[1] - 1 + stack_1 += self.window_size[1] - 1 + relative_coords = tf.stack([stack_0, stack_1], axis=2) + + self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32)) + + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.all_head_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.all_head_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.all_head_size]) + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> Tuple[tf.Tensor, ...]: + batch_size, dim, _ = shape_list(hidden_states) + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, (0, 1, 3, 2))) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + relative_position_bias = tf.gather( + self.relative_position_bias_table, tf.reshape(self.relative_position_index, (-1,)) + ) + relative_position_bias = tf.reshape( + relative_position_bias, + (self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1), + ) + + relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1)) + attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SwinModel call() function) + mask_shape = shape_list(attention_mask)[0] + attention_scores = tf.reshape( + attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim) + ) + attention_mask = tf.expand_dims(attention_mask, 1) + attention_mask = tf.expand_dims(attention_mask, 0) + attention_scores = attention_scores + attention_mask + attention_scores = tf.reshape(attention_scores, (-1, self.num_attention_heads, dim, dim)) + + # Normalize the attention scores to probabilities. + attention_probs = tf.nn.softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = tf.matmul(attention_probs, value_layer) + context_layer = tf.transpose(context_layer, (0, 2, 1, 3)) + new_context_layer_shape = shape_list(context_layer)[:-2] + [ + self.all_head_size, + ] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class TFSwinSelfOutput(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = keras.layers.Dense(dim, name="dense") + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout") + self.dim = dim + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.dim]) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + + +class TFSwinAttention(keras.layers.Layer): + def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None: + super().__init__(**kwargs) + self.self = TFSwinSelfAttention(config, dim, num_heads, name="self") + self.self_output = TFSwinSelfOutput(config, dim, name="output") + self.pruned_heads = set() + + def prune_heads(self, heads): + """ + Prunes heads of the model. 
See base class [`PreTrainedModel`]. `heads`: dict of {layer_num: list of heads to prune in this layer}.
+        """
+        raise NotImplementedError
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool = False,
+        training: bool = False,
+    ) -> tf.Tensor:
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions, training=training)
+        attention_output = self.self_output(self_outputs[0], hidden_states, training=training)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self", None) is not None:
+            with tf.name_scope(self.self.name):
+                self.self.build(None)
+        if getattr(self, "self_output", None) is not None:
+            with tf.name_scope(self.self_output.name):
+                self.self_output.build(None)
+
+
+class TFSwinIntermediate(keras.layers.Layer):
+    def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(int(config.mlp_ratio * dim), name="dense")
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.dim = dim
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.dim])
+
+
+class TFSwinOutput(keras.layers.Layer):
+    def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(dim, name="dense")
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+        self.config = config
+        self.dim = dim
+
+    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, int(self.config.mlp_ratio * self.dim)])
+
+
+class TFSwinLayer(keras.layers.Layer):
+    def __init__(
+        self,
+        config,
+        dim,
+        input_resolution: Tuple[int, int],
+        num_heads: int,
+        drop_path_rate: float = 0.0,
+        shift_size: int = 0,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        min_res = tf.reduce_min(input_resolution)
+        self.window_size = min_res if min_res <= config.window_size else config.window_size
+        self.shift_size = 0 if min_res <= self.window_size else shift_size
+        self.input_resolution = input_resolution
+
+        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+        self.attention = TFSwinAttention(config, dim, num_heads, name="attention")
+        self.drop_path = (
+            TFSwinDropPath(drop_path_rate, name="drop_path")
+            if drop_path_rate > 0.0
+            else keras.layers.Activation("linear", name="drop_path")
+        )
+        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
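+        # pre-norm residual layout: x = x + drop_path(attention(layernorm_before(x))), followed by
+        # x = x + output(intermediate(layernorm_after(x)))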
self.intermediate = TFSwinIntermediate(config, dim, name="intermediate") + self.swin_output = TFSwinOutput(config, dim, name="output") + self.dim = dim + + def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> tf.Tensor | None: + img_mask = tf.zeros((height, width)) + height_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1)) + width_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1)) + + # calculate attention mask for SW-MSA + if shift_size > 0: + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + height_inds = tf.range(height_slice[0] % height, height_slice[1] % height + 1) + width_inds = tf.range(width_slice[0] % width, width_slice[1] % width + 1) + indices = tf.reshape(tf.stack(tf.meshgrid(height_inds, width_inds), axis=-1), (-1, 2)) + if len(indices) >= 1: + updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count + img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates) + count += 1 + + img_mask = tf.expand_dims(img_mask, -1) + img_mask = tf.expand_dims(img_mask, 0) + + mask_windows = window_partition(img_mask, window_size) + mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size)) + attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2) + attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask) + attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask) + return attn_mask + + def maybe_pad( + self, hidden_states: tf.Tensor, window_size: int, height: int, width: int + ) -> Tuple[tf.Tensor, tf.Tensor]: + pad_right = (window_size - width % window_size) % window_size + pad_bottom = (window_size - height % window_size) % window_size + pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]] + hidden_states = tf.pad(hidden_states, pad_values) + pad_values = tf.reshape(pad_values, (-1,)) + return hidden_states, pad_values + + def call( + self, + hidden_states: tf.Tensor, + input_dimensions: Tuple[int, int], + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + training: bool = False, + ) -> tf.Tensor: + # if window size is larger than input resolution, we don't partition windows + min_res = tf.reduce_min(input_dimensions) + shift_size = 0 if min_res <= self.window_size else self.shift_size + window_size = min_res if min_res <= self.window_size else self.window_size + + height, width = input_dimensions + batch_size, _, channels = shape_list(hidden_states) + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states, training=training) + hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels)) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width) + + _, height_pad, width_pad, _ = shape_list(hidden_states) + # cyclic shift + if shift_size > 0: + shifted_hidden_states = tf.roll(hidden_states, shift=(-shift_size, -shift_size), axis=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, window_size) + hidden_states_windows = tf.reshape(hidden_states_windows, (-1, window_size * window_size, channels)) + attn_mask = self.get_attn_mask( + height=height_pad, width=width_pad, window_size=window_size, shift_size=shift_size + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training + ) + + 
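+        # attention_outputs[0] holds one row per window, shaped
+        # (num_windows * batch_size, window_size * window_size, channels); the windows are merged
+        # back into the padded feature map below before the cyclic shift is undone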
attention_output = attention_outputs[0] + + attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels)) + shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad) + + # reverse cyclic shift + if shift_size > 0: + attention_windows = tf.roll(shifted_windows, shift=(shift_size, shift_size), axis=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :] + + attention_windows = tf.reshape(attention_windows, (batch_size, height * width, channels)) + + hidden_states = shortcut + self.drop_path(attention_windows, training=training) + + layer_output = self.layernorm_after(hidden_states, training=training) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.swin_output(layer_output, training=training) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.dim]) + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.dim]) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "swin_output", None) is not None: + with tf.name_scope(self.swin_output.name): + self.swin_output.build(None) + + +class TFSwinStage(keras.layers.Layer): + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: Tuple[int, int], + depth: int, + num_heads: int, + drop_path: List[float], + downsample: Optional[Callable], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.config = config + self.dim = dim + self.blocks = [ + TFSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + drop_path_rate=drop_path[i], + name=f"blocks.{i}", + ) + for i in range(depth) + ] + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, + dim=dim, + norm_layer=partial(keras.layers.LayerNormalization, epsilon=1e-5), + name="downsample", + ) + else: + self.downsample = None + + self.pointing = False + + def call( + self, + hidden_states: tf.Tensor, + input_dimensions: Tuple[int, int], + head_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = False, + training: bool = False, + ) -> Tuple[tf.Tensor, ...]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training + ) + + hidden_states = layer_outputs[0] + + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, 
width_downsampled) + hidden_states = self.downsample(layer_outputs[0], input_dimensions, training=training) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "downsample", None) is not None: + with tf.name_scope(self.downsample.name): + self.downsample.build(None) + if getattr(self, "blocks", None) is not None: + for layer in self.blocks: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwinEncoder(keras.layers.Layer): + def __init__(self, config: SwinConfig, grid_size: Tuple[int, int], **kwargs): + super().__init__(**kwargs) + self.num_layers = len(config.depths) + self.config = config + dpr = list((tf.linspace(0, 1, sum(config.depths)) * config.drop_path_rate).numpy()) + self.layers = [ + TFSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=TFSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + name=f"layers.{i_layer}", + ) + for i_layer in range(self.num_layers) + ] + + self.gradient_checkpointing = False + + def call( + self, + hidden_states: tf.Tensor, + input_dimensions: Tuple[int, int], + head_mask: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> Union[Tuple[tf.Tensor, ...], TFSwinEncoderOutput]: + all_input_dimensions = () + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = shape_list(hidden_states) + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) + reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training + ) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[1] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + batch_size, _, hidden_size = shape_list(hidden_states) + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) + reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return TFSwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + 
reshaped_hidden_states=all_reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFSwinPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + + +SWIN_START_DOCSTRING = r""" + This model is a Tensorflow + [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a + regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`SwinConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWIN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def normalize_data_format(value: str) -> str: + """ + From tensorflow addons + https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/utils/keras_utils.py#L71 + """ + if value is None: + value = keras.backend.image_data_format() + data_format = value.lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + 'The `data_format` argument must be one of "channels_first", "channels_last". Received: ' + str(value) + ) + return data_format + + +class AdaptiveAveragePooling1D(keras.layers.Layer): + """ + Args: + Average 1D Pooling with adaptive kernel size. + output_size: An integer or tuple/list of a single integer, specifying pooled_features. + The new size of output channels. + data_format: A string, + one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, steps, channels)` while `channels_first` corresponds + to inputs with shape `(batch, channels, steps)`. + Input shape: + - If `data_format='channels_last'`: 3D tensor with shape `(batch, steps, channels)`. + - If `data_format='channels_first'`: 3D tensor with shape `(batch, channels, steps)`. + Output shape: + - If `data_format='channels_last'`: 3D tensor with shape `(batch_size, pooled_steps, channels)`. 
+ - If `data_format='channels_first'`: 3D tensor with shape `(batch_size, channels, pooled_steps)`. + + Adapted from [tensorflow-addon's adaptive pooling.py]( + https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/layers/adaptive_pooling.py#L90-L120 + ) + """ + + def __init__( + self, + output_size: Union[int, Iterable[int]], + reduce_function: Callable = tf.reduce_mean, + data_format: Optional[str] = None, + **kwargs, + ) -> None: + self.data_format = normalize_data_format(data_format) + self.reduce_function = reduce_function + self.output_size = (output_size,) if isinstance(output_size, int) else tuple(output_size) + super().__init__(**kwargs) + + def call(self, inputs: tf.Tensor, *args) -> None: + bins = self.output_size[0] + if self.data_format == "channels_last": + splits = tf.split(inputs, bins, axis=1) + splits = tf.stack(splits, axis=1) + out_vect = self.reduce_function(splits, axis=2) + else: + splits = tf.split(inputs, bins, axis=2) + splits = tf.stack(splits, axis=2) + out_vect = self.reduce_function(splits, axis=3) + return out_vect + + def compute_output_shape(self, input_shape: Iterable[int]) -> tf.TensorShape: + input_shape = tf.TensorShape(input_shape).as_list() + if self.data_format == "channels_last": + shape = tf.TensorShape([input_shape[0], self.output_size[0], input_shape[2]]) + else: + shape = tf.TensorShape([input_shape[0], input_shape[1], self.output_size[0]]) + return shape + + def get_config(self) -> Dict[str, Any]: + config = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_serializable +class TFSwinMainLayer(keras.layers.Layer): + config_class = SwinConfig + + def __init__( + self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs + ) -> None: + super().__init__(**kwargs) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = TFSwinEmbeddings(config, use_mask_token=use_mask_token, name="embeddings") + self.encoder = TFSwinEncoder(config, self.embeddings.patch_grid, name="encoder") + + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + self.pooler = AdaptiveAveragePooling1D(output_size=(1,)) if add_pooling_layer else None + + def get_input_embeddings(self) -> TFSwinPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List]): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_head_mask(self, head_mask: Optional[Any]) -> List: + if head_mask is not None: + raise NotImplementedError + return [None] * len(self.config.depths) + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask) + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, training=training + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output, training=training) + + pooled_output = None + if self.pooler is not None: + batch_size, _, num_features = shape_list(sequence_output) + pooled_output = self.pooler(sequence_output) + pooled_output = tf.reshape(pooled_output, (batch_size, num_features)) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + return output + + return TFSwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, self.num_features]) + + +@add_start_docstrings( + "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", + SWIN_START_DOCSTRING, +) +class TFSwinModel(TFSwinPreTrainedModel): + def __init__( + self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs + ) -> None: + super().__init__(config, **kwargs) + self.config = config + self.swin = TFSwinMainLayer(config, name="swin") + + 
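For orientation, a short usage sketch of the bare model defined here. The checkpoint name is the same one used in the masked-image-modeling example further below, and the printed shapes assume its swin-tiny configuration (224x224 input, final 7x7 grid, 8x the base embedding dimension of 96):

```python
# Usage sketch for TFSwinModel; nothing here is specific to this diff beyond
# the public classes it adds.
from transformers import AutoImageProcessor, TFSwinModel
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = TFSwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

inputs = image_processor(images=image, return_tensors="tf")
outputs = model(**inputs)

# last_hidden_state: (batch, (224/32)**2, 8 * embed_dim) = (1, 49, 768)
# pooler_output: the AdaptiveAveragePooling1D above collapses the 49 tokens to 1.
print(outputs.last_hidden_state.shape, outputs.pooler_output.shape)
```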
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSwinModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]: + r""" + bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + swin_outputs = self.swin( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return swin_outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "swin", None) is not None: + with tf.name_scope(self.swin.name): + self.swin.build(None) + + +class TFSwinPixelShuffle(keras.layers.Layer): + """TF layer implementation of torch.nn.PixelShuffle""" + + def __init__(self, upscale_factor: int, **kwargs) -> None: + super().__init__(**kwargs) + if not isinstance(upscale_factor, int) or upscale_factor < 2: + raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}") + self.upscale_factor = upscale_factor + + def call(self, x: tf.Tensor) -> tf.Tensor: + hidden_states = x + batch_size, _, _, num_input_channels = shape_list(hidden_states) + block_size_squared = self.upscale_factor**2 + output_depth = int(num_input_channels / block_size_squared) + # When the number of output channels >= 2, PyTorch's PixelShuffle and + # TF's depth_to_space differ in their output as the order of channels selected for combining + # is a permutation of the other c.f. 
+ # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1 + permutation = tf.constant( + [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]] + ) + hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1) + hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC") + return hidden_states + + +class TFSwinDecoder(keras.layers.Layer): + def __init__(self, config: SwinConfig, **kwargs): + super().__init__(**kwargs) + self.conv2d = keras.layers.Conv2D( + filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0" + ) + self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1") + self.config = config + + def call(self, x: tf.Tensor) -> tf.Tensor: + hidden_states = x + # B,C,H,W -> B,H,W,C + hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1)) + hidden_states = self.conv2d(hidden_states) + hidden_states = self.pixel_shuffle(hidden_states) + # B,H,W,C -> B,C,H,W + hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2)) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv2d", None) is not None: + with tf.name_scope(self.conv2d.name): + self.conv2d.build([None, None, None, self.config.hidden_size]) + if getattr(self, "pixel_shuffle", None) is not None: + with tf.name_scope(self.pixel_shuffle.name): + self.pixel_shuffle.build(None) + + +@add_start_docstrings( + "Swin Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://arxiv.org/abs/2111.09886).", + SWIN_START_DOCSTRING, +) +class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): + def __init__(self, config: SwinConfig): + super().__init__(config) + + self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin") + + self.decoder = TFSwinDecoder(config, name="decoder") + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + bool_masked_pos: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple, TFSwinMaskedImageModelingOutput]: + r""" + bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, TFSwinForMaskedImageModeling + >>> import tensorflow as tf + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224") + >>> model = TFSwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5 + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = tf.transpose(sequence_output, (0, 2, 1)) + batch_size, num_channels, sequence_length = shape_list(sequence_output) + height = width = int(sequence_length**0.5) + sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width)) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size)) + mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1) + mask = tf.repeat(mask, self.config.patch_size, 2) + mask = tf.expand_dims(mask, 1) + mask = tf.cast(mask, tf.float32) + + reconstruction_loss = keras.losses.mean_absolute_error( + # Swap axes as metric calculation reduces over the final dimension + tf.transpose(pixel_values, (1, 2, 3, 0)), + tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)), + ) + reconstruction_loss = tf.expand_dims(reconstruction_loss, 0) + total_loss = tf.reduce_sum(reconstruction_loss * mask) + num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels + masked_im_loss = total_loss / num_masked_pixels + masked_im_loss = tf.reshape(masked_im_loss, (1,)) + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[2:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return TFSwinMaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "swin", None) is not None: + with tf.name_scope(self.swin.name): + self.swin.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + """ + Swin Model transformer with an image classification head on top (a 
linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + SWIN_START_DOCSTRING, +) +class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: SwinConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.swin = TFSwinMainLayer(config, name="swin") + + # Classifier head + self.classifier = ( + keras.layers.Dense(config.num_labels, name="classifier") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="classifier") + ) + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSwinImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + labels: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor, ...], TFSwinImageClassifierOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swin( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSwinImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "swin", None) is not None: + with tf.name_scope(self.swin.name): + self.swin.build(None) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.swin.num_features]) + + +__all__ = ["TFSwinForImageClassification", "TFSwinForMaskedImageModeling", "TFSwinModel", "TFSwinPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/swin2sr/__init__.py b/docs/transformers/src/transformers/models/swin2sr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7ea677b5d5d6495c66ec753b636115848260d8 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin2sr/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_swin2sr import * + from .image_processing_swin2sr import * + from .modeling_swin2sr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/swin2sr/configuration_swin2sr.py b/docs/transformers/src/transformers/models/swin2sr/configuration_swin2sr.py new file mode 100644 index 0000000000000000000000000000000000000000..a507d9d625136c1710f72924f62edd563d803d2d --- /dev/null +++ b/docs/transformers/src/transformers/models/swin2sr/configuration_swin2sr.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swin2SR Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Swin2SRConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Swin2SRModel`]. It is used to instantiate a Swin + Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2 + [caidas/swin2sr-classicalsr-x2-64](https://huggingface.co/caidas/swin2sr-classicalsr-x2-64) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 64): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 1): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_channels_out (`int`, *optional*, defaults to `num_channels`): + The number of output channels. If not set, it will be set to `num_channels`. + embed_dim (`int`, *optional*, defaults to 180): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[6, 6, 6, 6, 6, 6]`): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to `[6, 6, 6, 6, 6, 6]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 8): + Size of windows. 
+        mlp_ratio (`float`, *optional*, defaults to 2.0):
+            Ratio of MLP hidden dimensionality to embedding dimensionality.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not a learnable bias should be added to the queries, keys and values.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings and encoder.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether or not to add absolute position embeddings to the patch embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        upscale (`int`, *optional*, defaults to 2):
+            The upscale factor for the image. 2/3/4/8 for image super resolution, 1 for denoising and compression
+            artifact reduction.
+        img_range (`float`, *optional*, defaults to 1.0):
+            The range of the values of the input image.
+        resi_connection (`str`, *optional*, defaults to `"1conv"`):
+            The convolutional block to use before the residual connection in each stage.
+        upsampler (`str`, *optional*, defaults to `"pixelshuffle"`):
+            The reconstruction module. Can be `'pixelshuffle'`/`'pixelshuffledirect'`/`'nearest+conv'`/`None`.
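Complementing the Example that follows, here is a small sketch of how the `attribute_map` defined just below lets generic configuration names resolve to Swin2SR-specific attributes. Only documented defaults are used; nothing is downloaded:

```python
# Sketch of attribute_map aliasing on the default Swin2SRConfig.
from transformers import Swin2SRConfig

config = Swin2SRConfig()
print(config.embed_dim)            # 180
print(config.hidden_size)          # 180 -- alias for embed_dim via attribute_map
print(config.num_hidden_layers)    # 6   -- alias for num_layers, which is len(depths)
print(config.num_attention_heads)  # [6, 6, 6, 6, 6, 6] -- alias for num_heads
```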
+ + Example: + + ```python + >>> from transformers import Swin2SRConfig, Swin2SRModel + + >>> # Initializing a Swin2SR caidas/swin2sr-classicalsr-x2-64 style configuration + >>> configuration = Swin2SRConfig() + + >>> # Initializing a model (with random weights) from the caidas/swin2sr-classicalsr-x2-64 style configuration + >>> model = Swin2SRModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "swin2sr" + + attribute_map = { + "hidden_size": "embed_dim", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=64, + patch_size=1, + num_channels=3, + num_channels_out=None, + embed_dim=180, + depths=[6, 6, 6, 6, 6, 6], + num_heads=[6, 6, 6, 6, 6, 6], + window_size=8, + mlp_ratio=2.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + upscale=2, + img_range=1.0, + resi_connection="1conv", + upsampler="pixelshuffle", + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_channels_out = num_channels if num_channels_out is None else num_channels_out + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.upscale = upscale + self.img_range = img_range + self.resi_connection = resi_connection + self.upsampler = upsampler + + +__all__ = ["Swin2SRConfig"] diff --git a/docs/transformers/src/transformers/models/swin2sr/convert_swin2sr_original_to_pytorch.py b/docs/transformers/src/transformers/models/swin2sr/convert_swin2sr_original_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..29fe2e3e25de1ff73abbe4b3600f9df63ca7e411 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin2sr/convert_swin2sr_original_to_pytorch.py @@ -0,0 +1,278 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Swin2SR checkpoints from the original repository. 
URL: https://github.com/mv-lab/swin2sr""" + +import argparse + +import requests +import torch +from PIL import Image +from torchvision.transforms import Compose, Normalize, Resize, ToTensor + +from transformers import Swin2SRConfig, Swin2SRForImageSuperResolution, Swin2SRImageProcessor + + +def get_config(checkpoint_url): + config = Swin2SRConfig() + + if "Swin2SR_ClassicalSR_X4_64" in checkpoint_url: + config.upscale = 4 + elif "Swin2SR_CompressedSR_X4_48" in checkpoint_url: + config.upscale = 4 + config.image_size = 48 + config.upsampler = "pixelshuffle_aux" + elif "Swin2SR_Lightweight_X2_64" in checkpoint_url: + config.depths = [6, 6, 6, 6] + config.embed_dim = 60 + config.num_heads = [6, 6, 6, 6] + config.upsampler = "pixelshuffledirect" + elif "Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR" in checkpoint_url: + config.upscale = 4 + config.upsampler = "nearest+conv" + elif "Swin2SR_Jpeg_dynamic" in checkpoint_url: + config.num_channels = 1 + config.upscale = 1 + config.image_size = 126 + config.window_size = 7 + config.img_range = 255.0 + config.upsampler = "" + + return config + + +def rename_key(name, config): + if "patch_embed.proj" in name and "layers" not in name: + name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "embeddings.patch_embeddings.layernorm") + if "layers" in name: + name = name.replace("layers", "encoder.stages") + if "residual_group.blocks" in name: + name = name.replace("residual_group.blocks", "layers") + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "q_bias" in name: + name = name.replace("q_bias", "query.bias") + if "k_bias" in name: + name = name.replace("k_bias", "key.bias") + if "v_bias" in name: + name = name.replace("v_bias", "value.bias") + if "cpb_mlp" in name: + name = name.replace("cpb_mlp", "continuous_position_bias_mlp") + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "patch_embed.projection") + + if name == "norm.weight": + name = "layernorm.weight" + if name == "norm.bias": + name = "layernorm.bias" + + if "conv_first" in name: + name = name.replace("conv_first", "first_convolution") + + if ( + "upsample" in name + or "conv_before_upsample" in name + or "conv_bicubic" in name + or "conv_up" in name + or "conv_hr" in name + or "conv_last" in name + or "aux" in name + ): + # heads + if "conv_last" in name: + name = name.replace("conv_last", "final_convolution") + if config.upsampler in ["pixelshuffle", "pixelshuffle_aux", "nearest+conv"]: + if "conv_before_upsample.0" in name: + name = name.replace("conv_before_upsample.0", "conv_before_upsample") + if "upsample.0" in name: + name = name.replace("upsample.0", "upsample.convolution_0") + if "upsample.2" in name: + name = name.replace("upsample.2", "upsample.convolution_1") + name = "upsample." + name + elif config.upsampler == "pixelshuffledirect": + name = name.replace("upsample.0.weight", "upsample.conv.weight") + name = name.replace("upsample.0.bias", "upsample.conv.bias") + else: + pass + else: + name = "swin2sr." 
+ name + + return name + + +def convert_state_dict(orig_state_dict, config): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "qkv" in key: + key_split = key.split(".") + stage_num = int(key_split[1]) + block_num = int(key_split[4]) + dim = config.embed_dim + + if "weight" in key: + orig_state_dict[ + f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.query.weight" + ] = val[:dim, :] + orig_state_dict[f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.key.weight"] = ( + val[dim : dim * 2, :] + ) + orig_state_dict[ + f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.value.weight" + ] = val[-dim:, :] + else: + orig_state_dict[f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.query.bias"] = ( + val[:dim] + ) + orig_state_dict[f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.key.bias"] = ( + val[dim : dim * 2] + ) + orig_state_dict[f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.value.bias"] = ( + val[-dim:] + ) + pass + else: + orig_state_dict[rename_key(key, config)] = val + + return orig_state_dict + + +def convert_swin2sr_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub): + config = get_config(checkpoint_url) + model = Swin2SRForImageSuperResolution(config) + model.eval() + + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + new_state_dict = convert_state_dict(state_dict, config) + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + if len(missing_keys) > 0: + raise ValueError("Missing keys when converting: {}".format(missing_keys)) + for key in unexpected_keys: + if not ("relative_position_index" in key or "relative_coords_table" in key or "self_mask" in key): + raise ValueError(f"Unexpected key {key} in state_dict") + + # verify values + url = "https://github.com/mv-lab/swin2sr/blob/main/testsets/real-inputs/shanghai.jpg?raw=true" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + processor = Swin2SRImageProcessor() + # pixel_values = processor(image, return_tensors="pt").pixel_values + + image_size = 126 if "Jpeg" in checkpoint_url else 256 + transforms = Compose( + [ + Resize((image_size, image_size)), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + pixel_values = transforms(image).unsqueeze(0) + + if config.num_channels == 1: + pixel_values = pixel_values[:, 0, :, :].unsqueeze(1) + + outputs = model(pixel_values) + + # assert values + if "Swin2SR_ClassicalSR_X2_64" in checkpoint_url: + expected_shape = torch.Size([1, 3, 512, 512]) + expected_slice = torch.tensor( + [[-0.7087, -0.7138, -0.6721], [-0.8340, -0.8095, -0.7298], [-0.9149, -0.8414, -0.7940]] + ) + elif "Swin2SR_ClassicalSR_X4_64" in checkpoint_url: + expected_shape = torch.Size([1, 3, 1024, 1024]) + expected_slice = torch.tensor( + [[-0.7775, -0.8105, -0.8933], [-0.7764, -0.8356, -0.9225], [-0.7976, -0.8686, -0.9579]] + ) + elif "Swin2SR_CompressedSR_X4_48" in checkpoint_url: + # TODO values didn't match exactly here + expected_shape = torch.Size([1, 3, 1024, 1024]) + expected_slice = torch.tensor( + [[-0.8035, -0.7504, -0.7491], [-0.8538, -0.8124, -0.7782], [-0.8804, -0.8651, -0.8493]] + ) + elif "Swin2SR_Lightweight_X2_64" in checkpoint_url: + expected_shape = torch.Size([1, 3, 512, 512]) + expected_slice = torch.tensor( + [[-0.7669, -0.8662, -0.8767], [-0.8810, -0.9962, -0.9820], [-0.9340, -1.0322, -1.1149]] + 
) + elif "Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR" in checkpoint_url: + expected_shape = torch.Size([1, 3, 1024, 1024]) + expected_slice = torch.tensor( + [[-0.5238, -0.5557, -0.6321], [-0.6016, -0.5903, -0.6391], [-0.6244, -0.6334, -0.6889]] + ) + + assert outputs.reconstruction.shape == expected_shape, ( + f"Shape of reconstruction should be {expected_shape}, but is {outputs.reconstruction.shape}" + ) + assert torch.allclose(outputs.reconstruction[0, 0, :3, :3], expected_slice, atol=1e-3) + print("Looks ok!") + + url_to_name = { + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X2_64.pth": ( + "swin2SR-classical-sr-x2-64" + ), + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X4_64.pth": ( + "swin2SR-classical-sr-x4-64" + ), + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_CompressedSR_X4_48.pth": ( + "swin2SR-compressed-sr-x4-48" + ), + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_Lightweight_X2_64.pth": ( + "swin2SR-lightweight-x2-64" + ), + "https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth": ( + "swin2SR-realworld-sr-x4-64-bsrgan-psnr" + ), + } + model_name = url_to_name[checkpoint_url] + + if pytorch_dump_folder_path is not None: + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"caidas/{model_name}") + processor.push_to_hub(f"caidas/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X2_64.pth", + type=str, + help="URL of the original Swin2SR checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the converted model to the hub.") + + args = parser.parse_args() + convert_swin2sr_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/swin2sr/image_processing_swin2sr.py b/docs/transformers/src/transformers/models/swin2sr/image_processing_swin2sr.py new file mode 100644 index 0000000000000000000000000000000000000000..2832e531602e9d1232a02830c7a5aa9ee4c6d6f1 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin2sr/image_processing_swin2sr.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
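The trickiest step in `convert_state_dict` above is splitting the original checkpoint's fused `qkv` projection into the separate query/key/value tensors the HF model expects. A self-contained sketch of that slicing, using random tensors (the `dim` value matches the default `embed_dim` of 180):

```python
# Standalone sketch of the fused-qkv split performed in convert_state_dict:
# the original checkpoint stores one (3*dim, dim) "qkv" weight that maps onto
# three separate linear layers in the converted model.
import torch

dim = 180  # embed_dim of the default Swin2SR config
qkv_weight = torch.randn(3 * dim, dim)
qkv_bias = torch.randn(3 * dim)

query_w, key_w, value_w = qkv_weight[:dim, :], qkv_weight[dim : 2 * dim, :], qkv_weight[-dim:, :]
query_b, key_b, value_b = qkv_bias[:dim], qkv_bias[dim : 2 * dim], qkv_bias[-dim:]

# Concatenating the slices back recovers the fused tensor exactly.
assert torch.equal(torch.cat([query_w, key_w, value_w], dim=0), qkv_weight)
assert torch.equal(torch.cat([query_b, key_b, value_b], dim=0), qkv_bias)
```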
+"""Image processor class for Swin2SR.""" + +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import get_image_size, pad, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging + + +logger = logging.get_logger(__name__) + + +class Swin2SRImageProcessor(BaseImageProcessor): + r""" + Constructs a Swin2SR image processor. + + Args: + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_pad: bool = True, + pad_size: int = 8, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_pad = do_pad + self.pad_size = pad_size + + def pad( + self, + image: np.ndarray, + size: int, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pad an image to make the height and width divisible by `size`. + + Args: + image (`np.ndarray`): + Image to pad. + size (`int`): + The size to make the height and width divisible by. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The padded image. + """ + old_height, old_width = get_image_size(image, input_data_format) + pad_height = (old_height // size + 1) * size - old_height + pad_width = (old_width // size + 1) * size - old_width + + return pad( + image, + ((0, pad_height), (0, pad_width)), + mode="symmetric", + data_format=data_format, + input_data_format=input_data_format, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. 
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                Whether to pad the image to make the height and width divisible by `pad_size`.
+            pad_size (`int`, *optional*, defaults to `self.pad_size`):
+                The size to make the image height and width divisible by (typically the window size of the model's
+                local attention).
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_pad = do_pad if do_pad is not None else self.do_pad
+        pad_size = pad_size if pad_size is not None else self.pad_size
+
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_pad=do_pad,
+            size_divisibility=pad_size,  # Here the pad function simply requires pad_size.
+        )
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_pad: + images = [self.pad(image, size=pad_size, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Swin2SRImageProcessor"] diff --git a/docs/transformers/src/transformers/models/swin2sr/modeling_swin2sr.py b/docs/transformers/src/transformers/models/swin2sr/modeling_swin2sr.py new file mode 100644 index 0000000000000000000000000000000000000000..6543fe5cdff610ca342be893c3c61ee4dbc513c4 --- /dev/null +++ b/docs/transformers/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -0,0 +1,1183 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Swin2SR Transformer model.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_swin2sr import Swin2SRConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Swin2SRConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "caidas/swin2SR-classical-sr-x2-64" +_EXPECTED_OUTPUT_SHAPE = [1, 180, 488, 648] + + +@dataclass +class Swin2SREncoderOutput(ModelOutput): + """ + Swin2SR encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swin2SR +class Swin2SRDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Swin2SREmbeddings(nn.Module): + """ + Construct the patch and optional position embeddings. 
+ """ + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = Swin2SRPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.window_size = config.window_size + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]: + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +class Swin2SRPatchEmbeddings(nn.Module): + def __init__(self, config, normalize_patches=True): + super().__init__() + num_channels = config.embed_dim + image_size, patch_size = config.image_size, config.patch_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]] + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.projection = nn.Conv2d(num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size) + self.layernorm = nn.LayerNorm(config.embed_dim) if normalize_patches else None + + def forward(self, embeddings: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + embeddings = self.projection(embeddings) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + if self.layernorm is not None: + embeddings = self.layernorm(embeddings) + + return embeddings, output_dimensions + + +class Swin2SRPatchUnEmbeddings(nn.Module): + r"""Image to Patch Unembedding""" + + def __init__(self, config): + super().__init__() + + self.embed_dim = config.embed_dim + + def forward(self, embeddings, x_size): + batch_size, height_width, num_channels = embeddings.shape + embeddings = embeddings.transpose(1, 2).view(batch_size, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return embeddings + + +# Copied from transformers.models.swinv2.modeling_swinv2.Swinv2PatchMerging with Swinv2->Swin2SR +class Swin2SRPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
+ """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # [batch_size, height/2 * width/2, 4*num_channels] + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C] + + input_feature = self.reduction(input_feature) + input_feature = self.norm(input_feature) + + return input_feature + + +# Copied from transformers.models.swinv2.modeling_swinv2.Swinv2SelfAttention with Swinv2->Swin2SR +class Swin2SRSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + self.pretrained_window_size = pretrained_window_size + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.continuous_position_bias_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float() + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float() + relative_coords_table = ( + torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # [1, 2*window_height - 1, 2*window_width - 1, 2] + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 + elif window_size > 1: + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, 
:, :, 1] /= self.window_size[1] - 1 + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8) + ) + # set to same dtype as mlp weight + relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype) + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index, persistent=False) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # cosine attention + attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( + key_layer, dim=-1 + ).transpose(-2, -1) + logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + attention_scores = attention_scores * logit_scale + relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view( + -1, self.num_attention_heads + ) + # [window_height*window_width,window_height*window_width,num_attention_heads] + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + # [num_attention_heads,window_height*window_width,window_height*window_width] + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in Swin2SRModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + 
attention_mask.unsqueeze(1).unsqueeze(0)
+            # the shifted-window mask (values 0 / -100) is added once above; flatten the window grouping
+            # back to (batch_size * num_windows, num_heads, dim, dim)
+            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swin2SR
+class Swin2SRSelfOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, dim)
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Attention with Swinv2->Swin2SR
+class Swin2SRAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
+        super().__init__()
+        self.self = Swin2SRSelfAttention(
+            config=config,
+            dim=dim,
+            num_heads=num_heads,
+            window_size=window_size,
+            pretrained_window_size=pretrained_window_size
+            if isinstance(pretrained_window_size, collections.abc.Iterable)
+            else (pretrained_window_size, pretrained_window_size),
+        )
+        self.output = Swin2SRSelfOutput(config, dim)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swin2SR
+class Swin2SRIntermediate(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swin2SR +class Swin2SROutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Layer with Swinv2->Swin2SR +class Swin2SRLayer(nn.Module): + def __init__( + self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0 + ): + super().__init__() + self.input_resolution = input_resolution + window_size, shift_size = self._compute_window_shift( + (config.window_size, config.window_size), (shift_size, shift_size) + ) + self.window_size = window_size[0] + self.shift_size = shift_size[0] + self.attention = Swin2SRAttention( + config=config, + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + pretrained_window_size=pretrained_window_size + if isinstance(pretrained_window_size, collections.abc.Iterable) + else (pretrained_window_size, pretrained_window_size), + ) + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.drop_path = Swin2SRDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.intermediate = Swin2SRIntermediate(config, dim) + self.output = Swin2SROutput(config, dim) + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + + def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return window_size, shift_size + + def get_attn_mask(self, height, width, dtype): + if self.shift_size > 0: + # calculate attention mask for shifted window multihead self attention + img_mask = torch.zeros((1, height, width, 1), dtype=dtype) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % 
self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + # pad hidden_states to multiples of window size + hidden_states = hidden_states.view(batch_size, height, width, channels) + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + hidden_states = self.layernorm_before(attention_windows) + hidden_states = shortcut + self.drop_path(hidden_states) + + layer_output = self.intermediate(hidden_states) + layer_output = self.output(layer_output) + layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output)) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class Swin2SRStage(nn.Module): + """ + This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation. 
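+
+    Each stage chains `depth` `Swin2SRLayer` blocks (even-indexed blocks use regular windows, odd-indexed blocks use
+    shifted windows), unembeds the tokens back into a 2D feature map, applies the `resi_connection` convolution,
+    re-embeds the result, and finally adds the stage input back as a residual.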
+ """ + + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, pretrained_window_size=0): + super().__init__() + self.config = config + self.dim = dim + self.layers = nn.ModuleList( + [ + Swin2SRLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + pretrained_window_size=pretrained_window_size, + ) + for i in range(depth) + ] + ) + + if config.resi_connection == "1conv": + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif config.resi_connection == "3conv": + # to save parameters and memory + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1), + ) + + self.patch_embed = Swin2SRPatchEmbeddings(config, normalize_patches=False) + + self.patch_unembed = Swin2SRPatchUnEmbeddings(config) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + residual = hidden_states + + height, width = input_dimensions + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + output_dimensions = (height, width, height, width) + + hidden_states = self.patch_unembed(hidden_states, input_dimensions) + hidden_states = self.conv(hidden_states) + hidden_states, _ = self.patch_embed(hidden_states) + + hidden_states = hidden_states + residual + + stage_outputs = (hidden_states, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class Swin2SREncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_stages = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.stages = nn.ModuleList( + [ + Swin2SRStage( + config=config, + dim=config.embed_dim, + input_resolution=(grid_size[0], grid_size[1]), + depth=config.depths[stage_idx], + num_heads=config.num_heads[stage_idx], + drop_path=dpr[sum(config.depths[:stage_idx]) : sum(config.depths[: stage_idx + 1])], + pretrained_window_size=0, + ) + for stage_idx in range(self.num_stages) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, Swin2SREncoderOutput]: + all_input_dimensions = () + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + for i, stage_module in enumerate(self.stages): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions + ) + else: + layer_outputs = stage_module(hidden_states, 
input_dimensions, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[1] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return Swin2SREncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Swin2SRPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Swin2SRConfig + base_model_prefix = "swin2sr" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SWIN2SR_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Swin2SRConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWIN2SR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`Swin2SRImageProcessor.__call__`] for details. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare Swin2SR Model transformer outputting raw hidden-states without any specific head on top.", + SWIN2SR_START_DOCSTRING, +) +class Swin2SRModel(Swin2SRPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + if config.num_channels == 3 and config.num_channels_out == 3: + mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1) + else: + mean = torch.zeros(1, 1, 1, 1) + self.register_buffer("mean", mean, persistent=False) + + self.img_range = config.img_range + + self.first_convolution = nn.Conv2d(config.num_channels, config.embed_dim, 3, 1, 1) + self.embeddings = Swin2SREmbeddings(config) + self.encoder = Swin2SREncoder(config, grid_size=self.embeddings.patch_embeddings.patches_resolution) + + self.layernorm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps) + self.patch_unembed = Swin2SRPatchUnEmbeddings(config) + self.conv_after_body = nn.Conv2d(config.embed_dim, config.embed_dim, 3, 1, 1) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def pad_and_normalize(self, pixel_values): + _, _, height, width = pixel_values.size() + + # 1. pad + window_size = self.config.window_size + modulo_pad_height = (window_size - height % window_size) % window_size + modulo_pad_width = (window_size - width % window_size) % window_size + pixel_values = nn.functional.pad(pixel_values, (0, modulo_pad_width, 0, modulo_pad_height), "reflect") + + # 2. 
normalize + mean = self.mean.type_as(pixel_values) + pixel_values = (pixel_values - mean) * self.img_range + + return pixel_values + + @add_start_docstrings_to_model_forward(SWIN2SR_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: torch.FloatTensor, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + _, _, height, width = pixel_values.shape + + # some preprocessing: padding + normalization + pixel_values = self.pad_and_normalize(pixel_values) + + embeddings = self.first_convolution(pixel_values) + embedding_output, input_dimensions = self.embeddings(embeddings) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + sequence_output = self.patch_unembed(sequence_output, (height, width)) + sequence_output = self.conv_after_body(sequence_output) + embeddings + + if not return_dict: + output = (sequence_output,) + encoder_outputs[1:] + + return output + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Upsample(nn.Module): + """Upsample module. + + Args: + scale (`int`): + Scale factor. Supported scales: 2^n and 3. + num_features (`int`): + Channel number of intermediate features. + """ + + def __init__(self, scale, num_features): + super().__init__() + + self.scale = scale + if (scale & (scale - 1)) == 0: + # scale = 2^n + for i in range(int(math.log(scale, 2))): + self.add_module(f"convolution_{i}", nn.Conv2d(num_features, 4 * num_features, 3, 1, 1)) + self.add_module(f"pixelshuffle_{i}", nn.PixelShuffle(2)) + elif scale == 3: + self.convolution = nn.Conv2d(num_features, 9 * num_features, 3, 1, 1) + self.pixelshuffle = nn.PixelShuffle(3) + else: + raise ValueError(f"Scale {scale} is not supported. 
Supported scales: 2^n and 3.") + + def forward(self, hidden_state): + if (self.scale & (self.scale - 1)) == 0: + for i in range(int(math.log(self.scale, 2))): + hidden_state = self.__getattr__(f"convolution_{i}")(hidden_state) + hidden_state = self.__getattr__(f"pixelshuffle_{i}")(hidden_state) + + elif self.scale == 3: + hidden_state = self.convolution(hidden_state) + hidden_state = self.pixelshuffle(hidden_state) + + return hidden_state + + +class UpsampleOneStep(nn.Module): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + + Used in lightweight SR to save parameters. + + Args: + scale (int): + Scale factor. Supported scales: 2^n and 3. + in_channels (int): + Channel number of intermediate features. + out_channels (int): + Channel number of output features. + """ + + def __init__(self, scale, in_channels, out_channels): + super().__init__() + + self.conv = nn.Conv2d(in_channels, (scale**2) * out_channels, 3, 1, 1) + self.pixel_shuffle = nn.PixelShuffle(scale) + + def forward(self, x): + x = self.conv(x) + x = self.pixel_shuffle(x) + + return x + + +class PixelShuffleUpsampler(nn.Module): + def __init__(self, config, num_features): + super().__init__() + self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1) + self.activation = nn.LeakyReLU(inplace=True) + self.upsample = Upsample(config.upscale, num_features) + self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1) + + def forward(self, sequence_output): + x = self.conv_before_upsample(sequence_output) + x = self.activation(x) + x = self.upsample(x) + x = self.final_convolution(x) + + return x + + +class NearestConvUpsampler(nn.Module): + def __init__(self, config, num_features): + super().__init__() + if config.upscale != 4: + raise ValueError("The nearest+conv upsampler only supports an upscale factor of 4 at the moment.") + + self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1) + self.activation = nn.LeakyReLU(inplace=True) + self.conv_up1 = nn.Conv2d(num_features, num_features, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_features, num_features, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_features, num_features, 3, 1, 1) + self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, sequence_output): + sequence_output = self.conv_before_upsample(sequence_output) + sequence_output = self.activation(sequence_output) + sequence_output = self.lrelu( + self.conv_up1(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest")) + ) + sequence_output = self.lrelu( + self.conv_up2(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest")) + ) + reconstruction = self.final_convolution(self.lrelu(self.conv_hr(sequence_output))) + return reconstruction + + +class PixelShuffleAuxUpsampler(nn.Module): + def __init__(self, config, num_features): + super().__init__() + + self.upscale = config.upscale + self.conv_bicubic = nn.Conv2d(config.num_channels, num_features, 3, 1, 1) + self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1) + self.activation = nn.LeakyReLU(inplace=True) + self.conv_aux = nn.Conv2d(num_features, config.num_channels, 3, 1, 1) + self.conv_after_aux = nn.Sequential(nn.Conv2d(3, num_features, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(config.upscale, num_features) + self.final_convolution = 
nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1) + + def forward(self, sequence_output, bicubic, height, width): + bicubic = self.conv_bicubic(bicubic) + sequence_output = self.conv_before_upsample(sequence_output) + sequence_output = self.activation(sequence_output) + aux = self.conv_aux(sequence_output) + sequence_output = self.conv_after_aux(aux) + sequence_output = ( + self.upsample(sequence_output)[:, :, : height * self.upscale, : width * self.upscale] + + bicubic[:, :, : height * self.upscale, : width * self.upscale] + ) + reconstruction = self.final_convolution(sequence_output) + + return reconstruction, aux + + +@add_start_docstrings( + """ + Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration. + """, + SWIN2SR_START_DOCSTRING, +) +class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.swin2sr = Swin2SRModel(config) + self.upsampler = config.upsampler + self.upscale = config.upscale + + # Upsampler + num_features = 64 + if self.upsampler == "pixelshuffle": + self.upsample = PixelShuffleUpsampler(config, num_features) + elif self.upsampler == "pixelshuffle_aux": + self.upsample = PixelShuffleAuxUpsampler(config, num_features) + elif self.upsampler == "pixelshuffledirect": + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(config.upscale, config.embed_dim, config.num_channels_out) + elif self.upsampler == "nearest+conv": + # for real-world SR (less artifacts) + self.upsample = NearestConvUpsampler(config, num_features) + else: + # for image denoising and JPEG compression artifact reduction + self.final_convolution = nn.Conv2d(config.embed_dim, config.num_channels_out, 3, 1, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWIN2SR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageSuperResolutionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageSuperResolutionOutput]: + r""" + Returns: + + Example: + ```python + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution + + >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64") + >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64") + + >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> # prepare image for the model + >>> inputs = processor(image, return_tensors="pt") + + >>> # forward pass + >>> with torch.no_grad(): + ... 
outputs = model(**inputs) + + >>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy() + >>> output = np.moveaxis(output, source=0, destination=-1) + >>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 + >>> # you can visualize `output` with `Image.fromarray` + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + if labels is not None: + raise NotImplementedError("Training is not supported at the moment") + + height, width = pixel_values.shape[2:] + + if self.config.upsampler == "pixelshuffle_aux": + bicubic = nn.functional.interpolate( + pixel_values, + size=(height * self.upscale, width * self.upscale), + mode="bicubic", + align_corners=False, + ) + + outputs = self.swin2sr( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + if self.upsampler in ["pixelshuffle", "pixelshuffledirect", "nearest+conv"]: + reconstruction = self.upsample(sequence_output) + elif self.upsampler == "pixelshuffle_aux": + reconstruction, aux = self.upsample(sequence_output, bicubic, height, width) + aux = aux / self.swin2sr.img_range + self.swin2sr.mean + else: + reconstruction = pixel_values + self.final_convolution(sequence_output) + + reconstruction = reconstruction / self.swin2sr.img_range + self.swin2sr.mean + reconstruction = reconstruction[:, :, : height * self.upscale, : width * self.upscale] + + if not return_dict: + output = (reconstruction,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageSuperResolutionOutput( + loss=loss, + reconstruction=reconstruction, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Swin2SRForImageSuperResolution", "Swin2SRModel", "Swin2SRPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/swinv2/__init__.py b/docs/transformers/src/transformers/models/swinv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..172f1573f1eb1bca4893b2495879346c0b044b9c --- /dev/null +++ b/docs/transformers/src/transformers/models/swinv2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
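+# This package is lazily materialized: under TYPE_CHECKING the submodules are imported eagerly so that static
+# analyzers see the real symbols, while at runtime the module object is replaced by a _LazyModule that only
+# imports configuration_swinv2 / modeling_swinv2 when one of their attributes is first accessed.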
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_swinv2 import * + from .modeling_swinv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/swinv2/configuration_swinv2.py b/docs/transformers/src/transformers/models/swinv2/configuration_swinv2.py new file mode 100644 index 0000000000000000000000000000000000000000..addb30e6a102dcdb835c90fe688045000922f21c --- /dev/null +++ b/docs/transformers/src/transformers/models/swinv2/configuration_swinv2.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swinv2 Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class Swinv2Config(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Swinv2Model`]. It is used to instantiate a Swin + Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2 + [microsoft/swinv2-tiny-patch4-window8-256](https://huggingface.co/microsoft/swinv2-tiny-patch4-window8-256) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + pretrained_window_sizes (`list(int)`, *optional*, defaults to `[0, 0, 0, 0]`): + Size of windows during pretraining. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. 
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to add absolute position embeddings to the patch embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + encoder_stride (`int`, *optional*, defaults to 32): + Factor to increase the spatial resolution by in the decoder head for masked image modeling. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. 
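+
+    When used as a backbone configuration, `out_features` and `out_indices` are kept aligned against the stage names
+    built in `__init__` (a small sketch, assuming the default four stages):
+
+    ```python
+    >>> config = Swinv2Config(out_features=["stage2", "stage4"])
+    >>> list(config.out_indices)
+    [2, 4]
+    ```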
+ + Example: + + ```python + >>> from transformers import Swinv2Config, Swinv2Model + + >>> # Initializing a Swinv2 microsoft/swinv2-tiny-patch4-window8-256 style configuration + >>> configuration = Swinv2Config() + + >>> # Initializing a model (with random weights) from the microsoft/swinv2-tiny-patch4-window8-256 style configuration + >>> model = Swinv2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "swinv2" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + pretrained_window_sizes=[0, 0, 0, 0], + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_stride=32, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.pretrained_window_sizes = pretrained_window_sizes + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.encoder_stride = encoder_stride + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + # we set the hidden_size attribute in order to make Swinv2 work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + + +__all__ = ["Swinv2Config"] diff --git a/docs/transformers/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py b/docs/transformers/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6e837a7e7e374cdb5ef7e7449c44c4857ab41f --- /dev/null +++ b/docs/transformers/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
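+# Example invocation (hypothetical dump path):
+#   python convert_swinv2_timm_to_pytorch.py \
+#       --swinv2_name swinv2_tiny_patch4_window8_256 \
+#       --pytorch_dump_folder_path ./swinv2-tiny-patch4-window8-256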
+"""Convert Swinv2 checkpoints from the timm library.""" + +import argparse +import json +from pathlib import Path + +import requests +import timm +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import AutoImageProcessor, Swinv2Config, Swinv2ForImageClassification + + +def get_swinv2_config(swinv2_name): + config = Swinv2Config() + name_split = swinv2_name.split("_") + + model_size = name_split[1] + if "to" in name_split[3]: + img_size = int(name_split[3][-3:]) + else: + img_size = int(name_split[3]) + if "to" in name_split[2]: + window_size = int(name_split[2][-2:]) + else: + window_size = int(name_split[2][6:]) + + if model_size == "tiny": + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + elif model_size == "small": + embed_dim = 96 + depths = (2, 2, 18, 2) + num_heads = (3, 6, 12, 24) + elif model_size == "base": + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + else: + embed_dim = 192 + depths = (2, 2, 18, 2) + num_heads = (6, 12, 24, 48) + + if "to" in swinv2_name: + config.pretrained_window_sizes = (12, 12, 12, 6) + + if ("22k" in swinv2_name) and ("to" not in swinv2_name): + num_classes = 21841 + repo_id = "huggingface/label-files" + filename = "imagenet-22k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + else: + num_classes = 1000 + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + config.image_size = img_size + config.num_labels = num_classes + config.embed_dim = embed_dim + config.depths = depths + config.num_heads = num_heads + config.window_size = window_size + + return config + + +def rename_key(name): + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "embeddings.norm") + if "layers" in name: + name = "encoder." + name + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "q_bias" in name: + name = name.replace("q_bias", "query.bias") + if "k_bias" in name: + name = name.replace("k_bias", "key.bias") + if "v_bias" in name: + name = name.replace("v_bias", "value.bias") + if "cpb_mlp" in name: + name = name.replace("cpb_mlp", "continuous_position_bias_mlp") + if name == "norm.weight": + name = "layernorm.weight" + if name == "norm.bias": + name = "layernorm.bias" + + if "head" in name: + name = name.replace("head", "classifier") + else: + name = "swinv2." 
+ name + + return name + + +def convert_state_dict(orig_state_dict, model): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "mask" in key: + continue + elif "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[1]) + block_num = int(key_split[3]) + dim = model.swinv2.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size + + if "weight" in key: + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight" + ] = val[:dim, :] + orig_state_dict[f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = ( + val[dim : dim * 2, :] + ) + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight" + ] = val[-dim:, :] + else: + orig_state_dict[f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = ( + val[:dim] + ) + orig_state_dict[f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = val[ + dim : dim * 2 + ] + orig_state_dict[f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = ( + val[-dim:] + ) + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +def convert_swinv2_checkpoint(swinv2_name, pytorch_dump_folder_path): + timm_model = timm.create_model(swinv2_name, pretrained=True) + timm_model.eval() + + config = get_swinv2_config(swinv2_name) + model = Swinv2ForImageClassification(config) + model.eval() + + new_state_dict = convert_state_dict(timm_model.state_dict(), model) + model.load_state_dict(new_state_dict) + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + image_processor = AutoImageProcessor.from_pretrained("microsoft/{}".format(swinv2_name.replace("_", "-"))) + image = Image.open(requests.get(url, stream=True).raw) + inputs = image_processor(images=image, return_tensors="pt") + + timm_outs = timm_model(inputs["pixel_values"]) + hf_outs = model(**inputs).logits + + assert torch.allclose(timm_outs, hf_outs, atol=1e-3) + + print(f"Saving model {swinv2_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + model.push_to_hub( + repo_path_or_name=Path(pytorch_dump_folder_path, swinv2_name), + organization="nandwalritik", + commit_message="Add model", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--swinv2_name", + default="swinv2_tiny_patch4_window8_256", + type=str, + help="Name of the Swinv2 timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_swinv2_checkpoint(args.swinv2_name, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/swinv2/modeling_swinv2.py b/docs/transformers/src/transformers/models/swinv2/modeling_swinv2.py new file mode 100644 index 0000000000000000000000000000000000000000..52e33cf3d99baba12046a806da7c3b8007c4b1ac --- /dev/null +++ b/docs/transformers/src/transformers/models/swinv2/modeling_swinv2.py @@ -0,0 +1,1458 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Swinv2 Transformer model.""" + +import collections.abc +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_swinv2 import Swinv2Config + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Swinv2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/swinv2-tiny-patch4-window8-256" +_EXPECTED_OUTPUT_SHAPE = [1, 64, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/swinv2-tiny-patch4-window8-256" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +# drop_path, Swinv2PatchEmbeddings, Swinv2PatchMerging and Swinv2DropPath are from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer_v2.py. + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Swinv2 +class Swinv2EncoderOutput(ModelOutput): + """ + Swinv2 encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. 
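+
+            The `(batch_size, hidden_size, height, width)` layout of `reshaped_hidden_states` is the one dense
+            prediction heads typically consume, whereas `hidden_states` keep the flattened sequence layout.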
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->Swinv2 +class Swinv2ModelOutput(ModelOutput): + """ + Swinv2 model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2 +class Swinv2MaskedImageModelingOutput(ModelOutput): + """ + Swinv2 masked image model outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Masked image modeling (MLM) loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2 +class Swinv2ImageClassifierOutput(ModelOutput): + """ + Swinv2 outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. 
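+
+    Example (an illustrative sketch; `outputs` is assumed to be a `Swinv2ImageClassifierOutput`
+    returned by `Swinv2ForImageClassification`):
+
+    ```python
+    >>> # `logits` has shape (batch_size, config.num_labels); take the argmax for the predicted class
+    >>> predicted_class_idx = outputs.logits.argmax(-1).item()
+    ```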
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swinv2 +class Swinv2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->Swinv2 +class Swinv2Embeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. 
+ """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = Swinv2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->Swinv2 +class Swinv2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape 
`(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +class Swinv2PatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
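+
+    Example (an illustrative sketch with toy sizes chosen for this docstring; patch merging halves
+    each spatial dimension and doubles the channel dimension):
+
+    ```python
+    >>> import torch
+    >>> merging = Swinv2PatchMerging(input_resolution=(8, 8), dim=96)
+    >>> merged = merging(torch.randn(1, 8 * 8, 96), input_dimensions=(8, 8))
+    >>> list(merged.shape)
+    [1, 16, 192]
+    ```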
+ """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # [batch_size, height/2 * width/2, 4*num_channels] + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C] + + input_feature = self.reduction(input_feature) + input_feature = self.norm(input_feature) + + return input_feature + + +class Swinv2SelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + self.pretrained_window_size = pretrained_window_size + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.continuous_position_bias_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float() + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float() + relative_coords_table = ( + torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # [1, 2*window_height - 1, 2*window_width - 1, 2] + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 + elif window_size > 1: + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + relative_coords_table *= 8 # normalize to -8, 8 + 
relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8) + ) + # set to same dtype as mlp weight + relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype) + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index, persistent=False) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # cosine attention + attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( + key_layer, dim=-1 + ).transpose(-2, -1) + logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + attention_scores = attention_scores * logit_scale + relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view( + -1, self.num_attention_heads + ) + # [window_height*window_width,window_height*window_width,num_attention_heads] + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + # [num_attention_heads,window_height*window_width,window_height*window_width] + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in Swinv2Model forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores + 
attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2 +class Swinv2SelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class Swinv2Attention(nn.Module): + def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0): + super().__init__() + self.self = Swinv2SelfAttention( + config=config, + dim=dim, + num_heads=num_heads, + window_size=window_size, + pretrained_window_size=pretrained_window_size + if isinstance(pretrained_window_size, collections.abc.Iterable) + else (pretrained_window_size, pretrained_window_size), + ) + self.output = Swinv2SelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swinv2 +class Swinv2Intermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + 
self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swinv2 +class Swinv2Output(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class Swinv2Layer(nn.Module): + def __init__( + self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0 + ): + super().__init__() + self.input_resolution = input_resolution + window_size, shift_size = self._compute_window_shift( + (config.window_size, config.window_size), (shift_size, shift_size) + ) + self.window_size = window_size[0] + self.shift_size = shift_size[0] + self.attention = Swinv2Attention( + config=config, + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + pretrained_window_size=pretrained_window_size + if isinstance(pretrained_window_size, collections.abc.Iterable) + else (pretrained_window_size, pretrained_window_size), + ) + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.drop_path = Swinv2DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.intermediate = Swinv2Intermediate(config, dim) + self.output = Swinv2Output(config, dim) + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + + def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return window_size, shift_size + + def get_attn_mask(self, height, width, dtype): + if self.shift_size > 0: + # calculate attention mask for shifted window multihead self attention + img_mask = torch.zeros((1, height, width, 1), dtype=dtype) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: 
Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + # pad hidden_states to multiples of window size + hidden_states = hidden_states.view(batch_size, height, width, channels) + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + hidden_states = self.layernorm_before(attention_windows) + hidden_states = shortcut + self.drop_path(hidden_states) + + layer_output = self.intermediate(hidden_states) + layer_output = self.output(layer_output) + layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output)) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class Swinv2Stage(nn.Module): + def __init__( + self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0 + ): + super().__init__() + self.config = config + self.dim = dim + blocks = [] + for i in range(depth): + block = Swinv2Layer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + pretrained_window_size=pretrained_window_size, + ) + blocks.append(block) + self.blocks = nn.ModuleList(blocks) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + 
output_attentions, + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class Swinv2Encoder(nn.Module): + def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + if self.config.pretrained_window_sizes is not None: + pretrained_window_sizes = config.pretrained_window_sizes + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + + layers = [] + for i_layer in range(self.num_layers): + stage = Swinv2Stage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None, + pretrained_window_size=pretrained_window_sizes[i_layer], + ) + layers.append(stage) + self.layers = nn.ModuleList(layers) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, Swinv2EncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask + ) + else: + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + ) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange b (h w) c -> b c h w + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, 
*(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states] + if v is not None + ) + + return Swinv2EncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +class Swinv2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Swinv2Config + base_model_prefix = "swinv2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Swinv2Stage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Swinv2Embeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, Swinv2SelfAttention): + module.logit_scale.data.fill_(math.log(10)) + + +SWINV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Swinv2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWINV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, default `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top.", + SWINV2_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinModel with SWIN->SWINV2,Swin->Swinv2 +class Swinv2Model(Swinv2PreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token) + self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Swinv2ModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Swinv2ModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
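+
+        Example (an illustrative sketch; the checkpoint name and the expected output shape are the
+        ones used for the code-sample docstrings at the top of this file):
+
+        ```python
+        >>> from transformers import AutoImageProcessor, Swinv2Model
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+        >>> model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> list(outputs.last_hidden_state.shape)
+        [1, 64, 768]
+        ```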
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return Swinv2ModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """Swinv2 Model with a decoder on top for masked image modeling, as proposed in +[SimMIM](https://arxiv.org/abs/2111.09886). + + + + Note that we provide a script to pre-train this model on custom data in our [examples + directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining). 
+ + + """, + SWINV2_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with swin->swinv2, base-simmim-window6-192->tiny-patch4-window8-256,SWIN->SWINV2,Swin->Swinv2,192->256 +class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True) + + num_features = int(config.embed_dim * 2 ** (config.num_layers - 1)) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1 + ), + nn.PixelShuffle(config.encoder_stride), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Swinv2MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 256, 256] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swinv2( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output.transpose(1, 2) + batch_size, num_channels, sequence_length = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, 
size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[2:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return Swinv2MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, + SWINV2_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with SWIN->SWINV2,Swin->Swinv2,swin->swinv2 +class Swinv2ForImageClassification(Swinv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.swinv2 = Swinv2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=Swinv2ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Swinv2ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
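+
+        Example (an illustrative sketch; the checkpoint and the expected label come from the
+        `_IMAGE_CLASS_CHECKPOINT` and `_IMAGE_CLASS_EXPECTED_OUTPUT` constants defined in this file):
+
+        ```python
+        >>> from transformers import AutoImageProcessor, Swinv2ForImageClassification
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+        >>> model = Swinv2ForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+        >>> model.config.id2label[logits.argmax(-1).item()]
+        'Egyptian cat'
+        ```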
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swinv2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return Swinv2ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Swinv2 backbone, to be used with frameworks like DETR and MaskFormer. + """, + SWINV2_START_DOCSTRING, +) +class Swinv2Backbone(Swinv2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + self.embeddings = Swinv2Embeddings(config) + self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid) + + # initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + >>> model = AutoBackbone.from_pretrained( + ... "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output, input_dimensions = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=None, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + return_dict=return_dict, + ) + + hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs[1],) + if output_attentions: + output += (outputs[2],) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Swinv2ForImageClassification", + "Swinv2ForMaskedImageModeling", + "Swinv2Model", + "Swinv2PreTrainedModel", + "Swinv2Backbone", +] diff --git a/docs/transformers/src/transformers/models/switch_transformers/__init__.py b/docs/transformers/src/transformers/models/switch_transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70afa7006d73dd2ecf9c421edd78e230cae4c731 --- /dev/null +++ b/docs/transformers/src/transformers/models/switch_transformers/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_switch_transformers import * + from .modeling_switch_transformers import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/docs/transformers/src/transformers/models/switch_transformers/configuration_switch_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..093148f601046bc47800ed80834bb7f545dc810b --- /dev/null +++ b/docs/transformers/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2022, Google and HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Switch Transformers model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class SwitchTransformersConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SwitchTransformersModel`]. It is used to + instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the + SwitchTransformers [google/switch-base-8](https://huggingface.co/google/switch-base-8) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`]. + d_model (`int`, *optional*, defaults to 768): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model // + num_heads`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `SwitchTransformersBlock`. + expert_capacity (`int`, *optional*, defaults to 64): + Maximum number of tokens that can be routed to each expert. If set to 1, the model will behave like a + regular Transformer. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_sparse_encoder_layers (`int`, *optional*, defaults to 3): + Number of sparse (MoE) layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_sparse_decoder_layers (`int`, *optional*, defaults to 3): + Number of sparse (MoE) layers in the Transformer decoder. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_experts (`int`, *optional*, defaults to 8): + Number of experts for each SwitchTransformer layer. + router_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the router. + router_jitter_noise (`float`, *optional*, defaults to 0.01): + Amount of noise to add to the router. + router_dtype (`str`, *optional*, defaults to `"float32"`): + The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the + *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
+ router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): + Whether to ignore padding tokens when routing. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + router_z_loss_coef (`float`, *optional*, defaults to 0.001): + The z loss factor for the total loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + dense_act_fn (`str`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. SwitchTransformersv1.1 + uses the `"gated-gelu"` feed forward projection. Original SwitchTransformers uses `"relu"`. + add_router_probs (`bool`, *optional*, defaults to `False`): + Whether to output router probabilities to compute router auxiliary loss. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "switch_transformers" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, + d_model=768, + d_kv=64, + d_ff=2048, + expert_capacity=64, + num_layers=12, + num_sparse_encoder_layers=3, + num_decoder_layers=12, + num_sparse_decoder_layers=3, + num_heads=12, + num_experts=8, + router_bias=False, + router_jitter_noise=0.01, + router_dtype="float32", + router_ignore_padding_tokens=False, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + router_z_loss_coef=0.001, + router_aux_loss_coef=0.001, + initializer_factor=1.0, + dense_act_fn="relu", + is_encoder_decoder=True, + add_router_probs=False, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + + self.num_sparse_encoder_layers = num_sparse_encoder_layers + + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_sparse_decoder_layers = num_sparse_decoder_layers + + # Interval between sparse layers in the encoder: every `encoder_sparse_step`-th layer is sparse. + if self.num_sparse_encoder_layers > 0: + self.encoder_sparse_step = self.num_layers // self.num_sparse_encoder_layers + else: + self.encoder_sparse_step = self.num_layers # HACK: this will create 0 sparse layers + + # Interval between sparse layers in the decoder: every `decoder_sparse_step`-th layer is sparse.
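+ # e.g. num_decoder_layers=12 and num_sparse_decoder_layers=3 give decoder_sparse_step=4; with the + # `is_sparse` rule in `SwitchTransformersStack` below, decoder layers 1, 5 and 9 are then sparse.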
+ if self.num_sparse_decoder_layers > 0: + self.decoder_sparse_step = self.num_decoder_layers // self.num_sparse_decoder_layers + else: + self.decoder_sparse_step = self.num_decoder_layers # HACK: this will create 0 sparse layers + + self.num_heads = num_heads + self.num_experts = num_experts + self.expert_capacity = expert_capacity + self.router_bias = router_bias + self.router_jitter_noise = router_jitter_noise + if router_dtype not in ["float32", "float16", "bfloat16"]: + raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}") + self.router_dtype = router_dtype + + self.router_ignore_padding_tokens = router_ignore_padding_tokens + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.use_cache = use_cache + self.add_router_probs = add_router_probs + + self.router_z_loss_coef = router_z_loss_coef + self.router_aux_loss_coef = router_aux_loss_coef + self.dense_act_fn = dense_act_fn + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +__all__ = ["SwitchTransformersConfig"] diff --git a/docs/transformers/src/transformers/models/switch_transformers/convert_big_switch.py b/docs/transformers/src/transformers/models/switch_transformers/convert_big_switch.py new file mode 100644 index 0000000000000000000000000000000000000000..6f422439fc7e5c0cd09b97affd99e74ba4c741c6 --- /dev/null +++ b/docs/transformers/src/transformers/models/switch_transformers/convert_big_switch.py @@ -0,0 +1,194 @@ +import argparse +import json +import os + +import tensorstore as ts +import torch +from flax import serialization +from flax.traverse_util import flatten_dict, unflatten_dict +from tensorflow.io import gfile + +from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import ( + rename_keys, +) +from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME +from transformers.utils.hub import convert_file_size_to_int + + +def rename_base_flax_keys(flax_key_tuple, flax_tensor): + """ + Rename basic JAX parameter keys to their PyTorch equivalents, reshaping tensors where needed.
+ """ + if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3: + # expert layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = torch.permute(flax_tensor, (0, 2, 1)) + elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple): + # linear layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = flax_tensor.T + elif flax_key_tuple[-1] in ["scale", "embedding"]: + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + + return flax_key_tuple, flax_tensor + + +def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path): + if "metadata" in layer: + split_layer = layer.split("metadata") + curr_real_layer_name = "".join(split_layer[0])[:-1] + split_layer = [tuple(("metadata" + split_layer[1]).split("/"))] + elif "kvstore" in layer: + split_layer = layer.split("kvstore") + curr_real_layer_name = "".join(split_layer[0])[:-1] + split_layer = [tuple(("kvstore" + split_layer[1]).split("/"))] + + else: + split_layer = layer.split("/") + curr_real_layer_name = "/".join(split_layer[:-1]) + split_layer[-1] = (split_layer[-1],) + + if "kvstore/path" in layer: + content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}" + elif "kvstore/driver" in layer: + content = "file" + else: + content = checkpoint_info[layer] + + return curr_real_layer_name, split_layer, content + + +def rename_and_save_block(current_block, save_path): + current_block = rename_keys(current_block) + new_current_block = {} + for k, v in current_block.items(): + new_current_block[k.replace("/", ".")] = v + current_block = new_current_block + torch.save(current_block, save_path) + + +def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weights_name: str = WEIGHTS_NAME): + max_shard_size = convert_file_size_to_int(max_shard_size) + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + os.makedirs(dump_path, exist_ok=True) + with gfile.GFile(switch_checkpoint_path + "/checkpoint", "rb") as fp: + checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"] + checkpoint_info = flatten_dict(checkpoint_info, sep="/") + + all_layers = {} + for layer in checkpoint_info.keys(): + curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict( + layer, checkpoint_info, switch_checkpoint_path + ) + if curr_real_layer_name in all_layers: + all_layers[curr_real_layer_name][split_layer[-1]] = content + else: + all_layers[curr_real_layer_name] = {split_layer[-1]: content} + + for key in all_layers.keys(): + # open tensorstore file + raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result() + raw_weights = torch.tensor(raw_weights) + weight_size = raw_weights.numel() * raw_weights.element_size() + + # use the renaming pattern from the small conversion scripts + key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights) + key = "/".join(key) + + # If this weight is going to tip up over the maximal size, we split. 
+ if current_block_size + weight_size > max_shard_size: + save_path = os.path.join( + dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin") + ) + rename_and_save_block(current_block, save_path) + sharded_state_dicts.append(current_block.keys()) + del current_block + current_block = {} + current_block_size = 0 + + current_block[key] = raw_weights.to(getattr(torch, dtype)) + current_block_size += weight_size + total_size += weight_size + + # Add the last block + save_path = os.path.join( + dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin") + ) + rename_and_save_block(current_block, save_path) + sharded_state_dicts.append(current_block.keys()) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace( + ".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin" + ) + temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx + 1:05d}-of-???.bin")) + os.rename(temp_filename, os.path.join(dump_path, shard_file)) + shards[shard_file] = shard + for key in shard: + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + + with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + return metadata, index + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--switch_t5x_checkpoint_path", + default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600", + type=str, + required=False, + help="Path to a directory containing a folder per layer. Follows the original Google format.", + ) + parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size") + parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model") + parser.add_argument( + "--pytorch_dump_folder_path", + default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted", + type=str, + required=False, + help="Path to the output pytorch model.", + ) + args = parser.parse_args() + shard_on_the_fly( + args.switch_t5x_checkpoint_path, + args.pytorch_dump_folder_path, + args.max_shard_size, + args.dtype, + ) + + +def sanity_check(): + from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, T5Tokenizer + + config = SwitchTransformersConfig.from_pretrained("google/switch-base-8") + config.save_pretrained("/home/arthur_huggingface_co/transformers/switch_converted") + model = SwitchTransformersForConditionalGeneration.from_pretrained( + "/home/arthur_huggingface_co/transformers/switch_converted", device_map="auto" + ) + + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
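+ # "<extra_id_k>" are T5-style sentinel tokens marking masked spans; the model is expected to generate + # the content of those spans.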
+ + input_ids = tokenizer(text, return_tensors="pt").input_ids + out = model.generate(input_ids, decoder_start_token_id=0) + print(tokenizer.decode(out[0])) diff --git a/docs/transformers/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..71d304ea96c6619c9b62b5981b885efa398fd0d3 --- /dev/null +++ b/docs/transformers/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert SwitchTransformers checkpoints from the original T5X repository to PyTorch.""" + +import argparse +import re + +from flax.traverse_util import flatten_dict, unflatten_dict +from t5x import checkpoints + +from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration +from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model +from transformers.utils import logging + + +logging.set_verbosity_info() + + +# should not include what is already done by the `from_pt` argument +MOE_LAYER_NAME_MAPPING = { + "/attention/": "/0/SelfAttention/", + "/self_attention/": "/0/SelfAttention/", + "/encoder_decoder_attention/": "/1/EncDecAttention/", + "value": "v", + "query": "q", + "key": "k", + "out": "o", + "pre_self_attention_layer_norm": "0/layer_norm", + "pre_cross_attention_layer_norm": "1/layer_norm", + "pre_attention_layer_norm": "0/layer_norm", # previously 1, but seems wrong + "token_embedder": "shared", + "encoder_norm": "final_layer_norm", + "decoder_norm": "final_layer_norm", + "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", + "router/router_weights/w/": "router/classifier/", + "roer/roer_weights/w/": "router/classifier/", + "logits_dense": "lm_head", +} + + +def rename_keys(s_dict): + # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in + # the original model + keys = list(s_dict.keys()) + for key in keys: + layer_to_block_of_layer = r".*/layers_(\d+)" + new_key = key + if re.match(layer_to_block_of_layer, key): + new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key) + + layer_to_block_of_layer = r"(encoder|decoder)\/" + + if re.match(layer_to_block_of_layer, key): + groups = re.match(layer_to_block_of_layer, new_key).groups() + if groups[0] == "encoder": + new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) + + elif groups[0] == "decoder": + new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/2/layer_norm/", new_key) + + # 2. 
Convert other classic mappings + for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): + if old_key in new_key: + new_key = new_key.replace(old_key, temp_key) + + print(f"{key} -> {new_key}") + s_dict[new_key] = s_dict.pop(key) + + if "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: + s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T + if "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: + s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T + + # 3. Take extra care of the EXPERTS layer + for key in list(s_dict.keys()): + if "expert" in key: + num_experts = s_dict[key].shape[0] + expert_weights = s_dict[key] + for idx in range(num_experts): + s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weights[idx] + print(f"{key} -> {key.replace('expert/', f'experts/expert_{idx}/')}") + + s_dict.pop(key) + + return s_dict + + +GIN_TO_CONFIG_MAPPING = { + "NUM_ENCODER_LAYERS": "num_layers", + "NUM_DECODER_LAYERS": "num_decoder_layers", + "NUM_HEADS": "num_heads", + "HEAD_DIM": "d_kv", + "EMBED_DIM": "d_model", + "MLP_DIM": "d_ff", + "NUM_SELECTED_EXPERTS": "num_selected_experts", + "NUM_ENCODER_SPARSE_LAYERS": "num_sparse_encoder_layers", + "NUM_DECODER_SPARSE_LAYERS": "num_sparse_decoder_layers", + "dense.MlpBlock.activations": "feed_forward_proj", +} + + +def convert_gin_to_config(gin_file, num_experts): + # Convert a google style config to the hugging face format + import regex as re + + with open(gin_file, "r") as f: + raw_gin = f.read() + + regex_match = re.findall(r"(.*) = ([0-9.]*)", raw_gin) + args = {} + for param, value in regex_match: + if param in GIN_TO_CONFIG_MAPPING and value != "": + args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if "." in value else int(value) + + activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0] + args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1]) + + args["num_experts"] = num_experts + config = SwitchTransformersConfig(**args) + return config + + +def convert_flax_checkpoint_to_pytorch( + flax_checkpoint_path, config_file, gin_file=None, pytorch_dump_path="./", num_experts=8 +): + # Initialise PyTorch model + + print(f"Loading flax weights from : {flax_checkpoint_path}") + flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path) + + if gin_file is not None: + config = convert_gin_to_config(gin_file, num_experts) + else: + config = SwitchTransformersConfig.from_pretrained(config_file) + + pt_model = SwitchTransformersForConditionalGeneration(config) + + flax_params = flax_params["target"] + flax_params = flatten_dict(flax_params, sep="/") + flax_params = rename_keys(flax_params) + flax_params = unflatten_dict(flax_params, sep="/") + + # Load the flax params in the PT model + load_flax_weights_in_pytorch_model(pt_model, flax_params) + + print(f"Save PyTorch model to {pytorch_dump_path}") + pt_model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--switch_t5x_checkpoint_path", + default=None, + type=str, + required=True, + help=( + "Path to the T5X checkpoint of the pre-trained SwitchTransformers model, following the original Google" + " format (a directory containing a folder per layer)."
+ ), + ) + parser.add_argument( + "--gin_file", + default=None, + type=str, + required=False, + help="Path to the gin config file. If not provided, a `config_file` has to be passed ", + ) + parser.add_argument( + "--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model." + ) + parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts") + args = parser.parse_args() + convert_flax_checkpoint_to_pytorch( + args.switch_t5x_checkpoint_path, + args.config_name, + args.gin_file, + args.pytorch_dump_folder_path, + args.num_experts, + ) diff --git a/docs/transformers/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/docs/transformers/src/transformers/models/switch_transformers/modeling_switch_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd8a139875bb180a3e1568b5e0bcd4d85d03d27 --- /dev/null +++ b/docs/transformers/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -0,0 +1,1997 @@ +# coding=utf-8 +# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SwitchTransformers model.""" + +import copy +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + MoEModelOutput, + MoEModelOutputWithPastAndCrossAttentions, + Seq2SeqMoEModelOutput, + Seq2SeqMoEOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_switch_transformers import SwitchTransformersConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SwitchTransformersConfig" +_CHECKPOINT_FOR_DOC = "google/switch-base-8" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +def router_z_loss_func(router_logits: torch.Tensor) -> float: + r""" + Compute the router z-loss implemented in PyTorch. 
+ + The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906). + It encourages router logits to remain small in an effort to improve stability. + + Args: + router_logits (`float`): + Input logits of shape [batch_size, sequence_length, num_experts] + + Returns: + Scalar router z-loss. + """ + num_groups, tokens_per_group, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + return torch.sum(z_loss) / (num_groups * tokens_per_group) + + +def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + router_probs (`torch.Tensor`): + Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts]. + expert_indices (`torch.Tensor`): + Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token. + + Returns: + The auxiliary loss. + """ + num_experts = router_probs.shape[-1] + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if expert_indices.dtype != torch.int64: + expert_indices = expert_indices.to(torch.int64) + + if len(expert_indices.shape) == 2: + expert_indices = expert_indices.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + + +class SwitchTransformersTop1Router(nn.Module): + """ + Router where each token chooses its top-1 expert assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each + token is processed by an expert**, or that each expert receives at least one token. + + """ + + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + self.num_experts = config.num_experts + self.expert_capacity = config.expert_capacity + self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) + self.jitter_noise = config.router_jitter_noise + self.ignore_padding_tokens = config.router_ignore_padding_tokens + self.dtype = getattr(torch, config.router_dtype) + + def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Computes router probabilities from input hidden states. + + Args: + hidden_states (`torch.Tensor`): + (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
+ Returns: + router_probabilities (`torch.Tensor`): + Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each + token and expert. Used for routing tokens to experts. + router_logits (`torch.Tensor`): + Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits. + This is used later for computing router z-loss. + """ + # float32 is used to ensure stability. See the discussion of "selective precision" in + # https://arxiv.org/abs/2101.03961. + # We also store the previous dtype to cast back the output to the previous dtype + self.input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(self.dtype) + + if self.training and self.jitter_noise > 0: + # Multiply the token inputs by values drawn uniformly from [1 - jitter_noise, 1 + jitter_noise] to add noise + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + + # Shape: [num_groups, tokens_per_group, num_experts] + self._cast_classifier() + router_logits = self.classifier(hidden_states) + + # Apply Softmax and cast back to the original `dtype` + router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype) + return router_probabilities, router_logits + + def _cast_classifier(self): + r""" + `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore we need to check whether the + classifier is an instance of the `Linear8bitLt` class by checking its special attributes. + """ + if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")): + self.classifier = self.classifier.to(self.dtype) + + def forward(self, hidden_states: torch.Tensor) -> Tuple: + r""" + Generic forward function for every Router class. Each router receives the same input hidden states + (`hidden_states`), one vector per token, and an `expert_capacity` giving the maximum number of tokens each + expert can process; some routers may send only a few tokens to each expert. + + Each Router works as follows: it takes the hidden states for each token and obtains the `router_probs` and + `router_logits` from the `router_weights`. These assign, for each token, the raw probability of routing it to + each expert. Each Router class then has to define its own `_compute_routing_instructions`. + + Args: + hidden_states (`torch.Tensor`) : + [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs + and the router logits. The router probabilities and logits are required to compute the loss. + """ + router_probs, router_logits = self._compute_router_probabilities(hidden_states) + + expert_index = torch.argmax(router_probs, dim=-1) + expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) + + # Mask tokens outside expert capacity. 
The cumulative sum over the sequence gives each token its priority. + token_priority = torch.cumsum(expert_index, dim=-2) + # mask tokens whose routing to the expert would overflow its capacity + expert_capacity_mask = token_priority <= self.expert_capacity + expert_index = expert_index * expert_capacity_mask + + router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) + return expert_index, router_probs, router_logits + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers +class SwitchTransformersLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the SwitchTransformers style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # SwitchTransformers uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->SwitchTransformers +class SwitchTransformersDenseActDense(nn.Module): + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class SwitchTransformersSparseMLP(nn.Module): + r""" + Implementation of the Switch Transformers Sparse MLP module. + """ + + def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): + super().__init__() + # Step 1: Get the correct router according to its class + self.router = SwitchTransformersTop1Router(config) + + # Step 2: Get the experts + self.experts = nn.ModuleDict() + for idx in range(config.num_experts): + self.experts[f"expert_{idx}"] = expert_class(config) + + def forward(self, hidden_states): + r""" + This is slightly tricky to understand. In order, a MoE layer does the following: + + 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_experts)` + and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the + hidden states: they are broadcast to the hidden states values (can be interpreted as a scaling factor). + + 2- Dispatches the tokens to their associated experts. 
We do a classic for loop over the experts and assign for each + expert the corresponding hidden states. + + """ + # Step 1: Get the router_mask from the router as well as the probabilities + router_mask, router_probs, router_logits = self.router(hidden_states) + expert_index = torch.argmax(router_mask, dim=-1) + + # If a token gets dropped, we just set it to zero such that it does not get updated. + next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype) + + router_mask = router_mask.bool() + batch_size, seq_len, num_experts = router_mask.shape + idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0) + idx_mask = torch.nonzero(idx_mask, as_tuple=True)[ + 0 + ].tolist() # length: number of "activated" expert / value: index + for idx in idx_mask: + next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))( + hidden_states[router_mask[:, :, idx]] + ) + + hidden_states = router_probs * next_states + return hidden_states, (router_logits, expert_index) + + +class SwitchTransformersLayerFF(nn.Module): + r""" + Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module. + + Parameters: + config : ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + is_sparse (`bool`): + Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not + """ + + def __init__(self, config: SwitchTransformersConfig, is_sparse=False): + super().__init__() + self.is_sparse = is_sparse + + # Check if it is a sparse layer, if not then it is a dense layer + if not self.is_sparse: + self.mlp = SwitchTransformersDenseActDense(config) + else: + self.mlp = SwitchTransformersSparseMLP(config) + + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states, output_router_logits): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.mlp(forwarded_states) + + if isinstance(forwarded_states, tuple): + forwarded_states, router_tuple = forwarded_states + else: + router_tuple = None + + output = hidden_states + self.dropout(forwarded_states) + + if output_router_logits and router_tuple is not None: + output = (output, router_tuple) + + return output + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers +class SwitchTransformersAttention(nn.Module): + def __init__( + self, + config: SwitchTransformersConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will 
lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * 
layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers +class SwitchTransformersLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = SwitchTransformersAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers +class SwitchTransformersLayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = SwitchTransformersAttention( + config, has_relative_attention_bias=False, layer_idx=layer_idx + ) + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class SwitchTransformersBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.is_sparse = is_sparse + self.layer = nn.ModuleList() + self.layer.append( + SwitchTransformersLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) + if self.is_decoder: + self.layer.append(SwitchTransformersLayerCrossAttention(config, layer_idx=layer_idx)) + + 
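+ # The feed forward sub-layer, dense MLP or sparse MoE depending on `is_sparse`, always comes last in the block.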
self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + output_router_logits=True, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states, past_key_value = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states, past_key_value = cross_attention_outputs[:2] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, output_router_logits) + + if isinstance(hidden_states, tuple): + hidden_states, router_tuple = hidden_states + else: + router_tuple = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,) + else: + outputs = outputs + attention_outputs + (router_tuple,) + + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) + + +class SwitchTransformersPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SwitchTransformersConfig + base_model_prefix = "switch_transformers" + supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False + _no_split_modules = ["SwitchTransformersBlock"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, SwitchTransformersLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, SwitchTransformersDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, SwitchTransformersAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + elif isinstance(module, SwitchTransformersSparseMLP): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1) + for idx in range(self.config.num_experts): + module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = 
self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set" + " to the pad_token_id. See SwitchTransformers docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class SwitchTransformersStack(SwitchTransformersPreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.is_decoder = config.is_decoder + + sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step + config.num_layers = config.num_decoder_layers if self.is_decoder else config.num_layers + self.block = nn.ModuleList() + for i in range(config.num_layers): + is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False + + self.block.append( + SwitchTransformersBlock( + config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse, layer_idx=i + ) + ) + + self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + + self.device_map = None + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_router_logits=True, + return_dict=None, + cache_position=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = 
"decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) 
+        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+        all_router_probs = () if output_router_logits else None
+        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+        position_bias = None
+        encoder_decoder_position_bias = None
+
+        hidden_states = self.dropout(inputs_embeds)
+
+        for i, layer_module in enumerate(self.block):
+            layer_head_mask = head_mask[i]
+            cross_attn_layer_head_mask = cross_attn_head_mask[i]
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.forward,
+                    hidden_states,
+                    causal_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    cross_attn_layer_head_mask,
+                    None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
+                    output_router_logits,
+                    return_dict,
+                    cache_position,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=past_key_values,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_router_logits=output_router_logits,
+                    return_dict=return_dict,
+                    cache_position=cache_position,
+                )
+
+            router_probs = layer_outputs[-1]
+            layer_outputs = layer_outputs[:-1]
+
+            # layer_outputs is a tuple with:
+            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
+            # (cross-attention position bias), (cross-attention weights)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+            hidden_states, next_decoder_cache = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer stores them
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[3],)
+                if self.is_decoder:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+            if output_router_logits:
+                all_router_probs = all_router_probs + (router_probs,)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        # Add last layer
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if return_self_attention_cache:
+            next_cache = past_key_values.self_attention_cache
+        if return_legacy_cache:
+            next_cache = past_key_values.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_attentions,
+                    all_cross_attentions,
+                    all_router_probs,
+                ]
+                if v is not None
+            )
+        return MoEModelOutputWithPastAndCrossAttentions(
+            
last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + router_probs=all_router_probs, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
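+            # (rows that are fully masked would otherwise produce NaNs out of the softmax)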
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, does nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+SWITCH_TRANSFORMERS_START_DOCSTRING = r"""
+
+    The SWITCH_TRANSFORMERS model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with
+    Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by [William
+    Fedus](https://arxiv.org/search/cs?searchtype=author&query=Fedus%2C+W), [Barret
+    Zoph](https://arxiv.org/search/cs?searchtype=author&query=Zoph%2C+B), and [Noam
+    Shazeer](https://arxiv.org/search/cs?searchtype=author&query=Shazeer%2C+N). It's an encoder-decoder T5-like model
+    whose sparse feed-forward layers implement a Mixture of Experts (MoE) architecture.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position
+            embeddings so you should be able to pad the inputs on both the right and the left.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            To know more on how to prepare `input_ids` for pretraining take a look at [SWITCH_TRANSFORMERS
+            Training](./switch_transformers#training).
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            SWITCH_TRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCH_TRANSFORMERS
+            Training](./switch_transformers#training).
+        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+            `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+            of `inputs_embeds`.
+
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+            should not be returned during inference.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+            cache in the correct position and to infer the complete sequence length.
+"""
+
+SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position
+            embeddings so you should be able to pad the inputs on both the right and the left.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            To know more on how to prepare `input_ids` for pretraining take a look at [SWITCH_TRANSFORMERS
+            Training](./switch_transformers#training).
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+            should not be returned during inference.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+num_heads)`.
+"""
+
+
+@add_start_docstrings(
+    "The bare SWITCH_TRANSFORMERS Model transformer outputting raw hidden-states without any specific head on top.",
+    SWITCH_TRANSFORMERS_START_DOCSTRING,
+)
+class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: SwitchTransformersConfig):
+        super().__init__(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = SwitchTransformersStack(encoder_config, self.shared)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        self.decoder = SwitchTransformersStack(decoder_config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.device_map = None
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+        self.decoder.set_input_embeddings(new_embeddings)
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, SwitchTransformersModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
+        >>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8")
+
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+        ... ).input_ids  # Batch size 1
+        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
+
+        >>> # preprocess: Prepend decoder_input_ids with the start token, which is the pad token for SwitchTransformersModel.
+        >>> # This is not needed for SwitchTransformersForConditionalGeneration as it does this internally using the labels arg.
+        >>> decoder_input_ids = model._shift_right(decoder_input_ids)
+
+        >>> # forward pass
+        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        >>> last_hidden_states = outputs.last_hidden_state
+        ```"""
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        if (
+            output_router_logits
+            and self.config.num_sparse_encoder_layers == 0
+            and self.config.num_sparse_decoder_layers == 0
+        ):
+            raise ValueError(
+                "You asked to return `output_router_logits` but the transformer is dense, and does "
+                "not contain any sparse MLP layers. Set `output_router_logits = False` and restart"
+            )
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                output_router_logits=output_router_logits,
+                return_dict=return_dict,
+            )
+        elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
+            encoder_outputs = MoEModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqMoEModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            decoder_router_logits=decoder_outputs.router_probs,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+            encoder_router_logits=encoder_outputs.router_probs,
+        )
+
+
+@add_start_docstrings(
+    """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING
+)
+class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+    def __init__(self, config: SwitchTransformersConfig):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = SwitchTransformersStack(encoder_config, self.shared)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = SwitchTransformersStack(decoder_config, self.shared)
+
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        self.router_z_loss_coef = config.router_z_loss_coef
+        self.router_aux_loss_coef = config.router_aux_loss_coef
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.device_map = None
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+        self.decoder.set_input_embeddings(new_embeddings)
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = True,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size - 1]`.
+            All labels set to `-100` are ignored (masked), the loss is only computed for
+            labels in `[0, ..., config.vocab_size - 1]`.
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
+        >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
+
+        >>> # training
+        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
+        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
+        >>> outputs = model(input_ids=input_ids, labels=labels)
+        >>> loss = outputs.loss
+        >>> logits = outputs.logits
+
+        >>> # inference
+        >>> input_ids = tokenizer(
+        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
+        ... ).input_ids  # Batch size 1
+        >>> outputs = model.generate(input_ids)
+        >>> # . To, let’s say you have a dog. To summarize:
+        >>> # Since the model has been trained on MLM, this will output gibberish
+        ```"""
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            # Convert encoder inputs in embeddings if needed
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                output_router_logits=output_router_logits,
+                return_dict=return_dict,
+            )
+        elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
+            encoder_outputs = MoEModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model_dim**-0.5)
+
+        lm_logits = self.lm_head(sequence_output)
+
+        loss = None
+        encoder_z_loss = 
None + encoder_aux_loss = None + decoder_z_loss = None + decoder_aux_loss = None + + if output_router_logits: + # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder + if self.encoder.config.encoder_sparse_step > 1: + encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs[-1]) + encoder_z_loss = router_z_loss_func(encoder_router_logits) + encoder_router_probs = nn.Softmax(dim=-1)(encoder_router_logits) + encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes) + else: + encoder_z_loss = 0 + encoder_aux_loss = 0 + + if self.decoder.config.decoder_sparse_step > 1: + decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs[-1]) + decoder_z_loss = router_z_loss_func(decoder_router_logits) + decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits) + decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes) + else: + decoder_z_loss = 0 + decoder_aux_loss = 0 + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if output_router_logits: + z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss) + aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss) + loss = loss + z_loss + aux_loss + + if not return_dict: + output = (lm_logits,) + if output_router_logits: + output += (encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss) + output += (*decoder_outputs[1:], *encoder_outputs) + + return ((loss,) + output) if loss is not None else output + + return Seq2SeqMoEOutput( + loss=loss, + logits=lm_logits, + encoder_z_loss=encoder_z_loss, + encoder_aux_loss=encoder_aux_loss, + decoder_z_loss=decoder_z_loss, + decoder_aux_loss=decoder_aux_loss, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + decoder_router_logits=decoder_outputs.router_probs, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_router_logits=encoder_outputs.router_probs, + ) + + def _unpack_router_logits(self, router_outputs): + total_router_logits = [] + total_expert_indexes = [] + for router_output in router_outputs: + if len(router_output[0].shape) > 1: + router_logits, expert_indexes = router_output + total_router_logits.append(router_logits) + total_expert_indexes.append(expert_indexes) + return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to 
set correct `past` for each of the four key / value states
+                reordered_layer_past_states = reordered_layer_past_states + (
+                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
+                )
+
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    "expected reordered_layer_past_states to have the same shape as layer_past_states, "
+                    f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    "expected layer_past_states to have the same length as reordered_layer_past_states, "
+                    f"but got {len(layer_past_states)} and {len(reordered_layer_past_states)}"
+                )
+
+            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+        return reordered_decoder_past
+
+
+@add_start_docstrings(
+    "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head"
+    " on top.",
+    SWITCH_TRANSFORMERS_START_DOCSTRING,
+)
+class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight"]
+
+    def __init__(self, config: SwitchTransformersConfig):
+        super().__init__(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = SwitchTransformersStack(encoder_config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.device_map = None
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=MoEModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = True,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], MoEModelOutput]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, SwitchTransformersEncoderModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
+        >>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+        ...
).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + return encoder_outputs + + +__all__ = [ + "SwitchTransformersEncoderModel", + "SwitchTransformersForConditionalGeneration", + "SwitchTransformersModel", + "SwitchTransformersPreTrainedModel", + "SwitchTransformersTop1Router", + "SwitchTransformersSparseMLP", +] diff --git a/docs/transformers/src/transformers/models/t5/__init__.py b/docs/transformers/src/transformers/models/t5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..366eab10826b19cd75e87b013f58f55976969abe --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_t5 import * + from .modeling_flax_t5 import * + from .modeling_t5 import * + from .modeling_tf_t5 import * + from .tokenization_t5 import * + from .tokenization_t5_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/t5/configuration_t5.py b/docs/transformers/src/transformers/models/t5/configuration_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..a22286449a27ddb1472cf4a28144ffe31313bb92 --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/configuration_t5.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""T5 model configuration"""
+
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class T5Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
+    instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the T5
+    [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Arguments:
+        vocab_size (`int`, *optional*, defaults to 32128):
+            Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
+        d_model (`int`, *optional*, defaults to 512):
+            Size of the encoder layers and the pooler layer.
+        d_kv (`int`, *optional*, defaults to 64):
+            Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
+            be defined as `num_heads * d_kv`.
+        d_ff (`int`, *optional*, defaults to 2048):
+            Size of the intermediate feed forward layer in each `T5Block`.
+        num_layers (`int`, *optional*, defaults to 6):
+            Number of hidden layers in the Transformer encoder.
+        num_decoder_layers (`int`, *optional*):
+            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+        num_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+            The number of buckets to use for each attention layer.
+        relative_attention_max_distance (`int`, *optional*, defaults to 128):
+            The maximum distance of the longer sequences for the bucket separation.
+        dropout_rate (`float`, *optional*, defaults to 0.1):
+            The ratio for all dropout layers.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the layer normalization layers.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
+            `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
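+
+    Example (a minimal usage sketch; `T5Model` stands in for any of the T5 model classes):
+
+    ```python
+    >>> from transformers import T5Config, T5Model
+
+    >>> # Initializing a configuration with the default (t5-small style) values
+    >>> configuration = T5Config()
+
+    >>> # Initializing a model from that configuration (weights are randomly initialized)
+    >>> model = T5Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```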
+    """
+
+    model_type = "t5"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "head_dim": "d_kv",
+    }
+
+    def __init__(
+        self,
+        vocab_size=32128,
+        d_model=512,
+        d_kv=64,
+        d_ff=2048,
+        num_layers=6,
+        num_decoder_layers=None,
+        num_heads=8,
+        relative_attention_num_buckets=32,
+        relative_attention_max_distance=128,
+        dropout_rate=0.1,
+        layer_norm_epsilon=1e-6,
+        initializer_factor=1.0,
+        feed_forward_proj="relu",
+        is_encoder_decoder=True,
+        use_cache=True,
+        pad_token_id=0,
+        eos_token_id=1,
+        classifier_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.d_kv = d_kv
+        self.d_ff = d_ff
+        self.num_layers = num_layers
+        self.num_decoder_layers = (
+            num_decoder_layers if num_decoder_layers is not None else self.num_layers
+        )  # default = symmetry
+        self.num_heads = num_heads
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        self.relative_attention_max_distance = relative_attention_max_distance
+        self.dropout_rate = dropout_rate
+        self.classifier_dropout = classifier_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_factor = initializer_factor
+        self.feed_forward_proj = feed_forward_proj
+        self.use_cache = use_cache
+
+        act_info = self.feed_forward_proj.split("-")
+        self.dense_act_fn = act_info[-1]
+        self.is_gated_act = act_info[0] == "gated"
+
+        if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
+            raise ValueError(
+                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
+                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                "'gated-gelu' or 'relu'"
+            )
+
+        # for backwards compatibility
+        if feed_forward_proj == "gated-gelu":
+            self.dense_act_fn = "gelu_new"
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            **kwargs,
+        )
+
+
+class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch", 1: "encoder_sequence"},
+            "attention_mask": {0: "batch", 1: "encoder_sequence"},
+        }
+        if self.use_past:
+            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+            common_inputs["decoder_input_ids"] = {0: "batch"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
+
+
+__all__ = ["T5Config", "T5OnnxConfig"]
diff --git a/docs/transformers/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b1b15857ceaa1f523eca5e1e542fe48e63d6651
--- /dev/null
+++ b/docs/transformers/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+# Copyright 2018 The T5 authors and HuggingFace Inc. team.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert T5 checkpoint.""" + +import argparse + +from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = T5ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_t5(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/docs/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py b/docs/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..12498359d21b30f05c09d639a277586b5daa3682 --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_flax.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Convert T5X checkpoints from the original repository to JAX/FLAX model.""" + +import argparse + +from t5x import checkpoints + +from transformers import FlaxT5ForConditionalGeneration, T5Config + + +def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path): + config = T5Config.from_pretrained(config_name) + flax_model = FlaxT5ForConditionalGeneration(config=config) + t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) + + split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"] + + # Encoder + for layer_index in range(config.num_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] + t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] + t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] + t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + + # Layer Normalization + t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + + if split_mlp_wi: + t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] + t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + + t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( + t5x_attention_key + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( + t5x_attention_out + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( + t5x_attention_query + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( + t5x_attention_value + ) + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( + t5x_attention_layer_norm + ) + + if split_mlp_wi: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = t5x_mlp_wi_0 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = t5x_mlp_wi_1 + else: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = ( + t5x_mlp_wi + ) + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = ( + t5x_mlp_wo + ) + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( + t5x_mlp_layer_norm + ) + + # Only for layer 0: + t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = t5x_encoder_rel_embedding + + # Assigning + t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"] + flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm + + # Decoder + for layer_index in range(config.num_decoder_layers): + 
layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] + t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] + t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] + t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + + # Layer Normalization + t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ + "scale" + ] + + # Encoder-Decoder-Attention + t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ + "kernel" + ] + t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ + "kernel" + ] + t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ + "kernel" + ] + t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ + "kernel" + ] + + # Layer Normalization + t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + + # MLP + if split_mlp_wi: + t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] + t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + + t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = ( + t5x_attention_key + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = ( + t5x_attention_out + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = ( + t5x_attention_query + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = ( + t5x_attention_value + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = ( + t5x_pre_attention_layer_norm + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = ( + t5x_enc_dec_attention_key + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = ( + t5x_enc_dec_attention_out + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = ( + t5x_enc_dec_attention_query + ) + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = ( + t5x_enc_dec_attention_value + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = ( + t5x_cross_layer_norm + ) + + if split_mlp_wi: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = t5x_mlp_wi_0 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = t5x_mlp_wi_1 + else: + 
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = ( + t5x_mlp_wi + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = ( + t5x_mlp_wo + ) + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = ( + tx5_mlp_layer_norm + ) + + # Decoder Normalization + tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"] + flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm + + # Only for layer 0: + t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = t5x_decoder_rel_embedding + + # Token Embeddings + tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"] + flax_model.params["shared"]["embedding"] = tx5_token_embeddings + + # LM Head (only in v1.1 checkpoints) + if "logits_dense" in t5x_model["target"]["decoder"]: + flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"] + + flax_model.save_pretrained(flax_dump_folder_path) + print("T5X Model was successfully converted!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + ) + parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of T5 model.") + parser.add_argument( + "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + ) + args = parser.parse_args() + convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7d9ef33d3e8a6c40a726983beab5b3ec6b67f4 --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2022 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert T5X checkpoint to PyTorch + +Steps: +- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install +- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example: + `gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/` +- Create or download a corresponding config for the downloaded model. E.g. 
for T5 v1.1 small, you can use + https://huggingface.co/google/t5-v1_1-small/blob/main/config.json +- Convert: + ``` + python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\ + --pytorch_dump_path=$HOME/t5_1_1_small_pt + ``` +""" + +import argparse +import collections + +import torch +from flax import traverse_util +from t5x import checkpoints + +from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def t5x_attention_lookup(params, i, prefix, layer_name="attention"): + """Returns the KOQV parameters of (self-)attention. Does not transpose.""" + k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] + o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] + q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] + v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] + return k, o, q, v + + +def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] + return wi, wo + + +def t5x_layer_norm_lookup(params, i, prefix, layer_name): + """Returns the layer norm param of a layer.""" + return params[f"{prefix}/layers_{i}/{layer_name}/scale"] + + +def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool): + """Converts the parameters from T5X-Flax to Transformers-PyTorch.""" + old = traverse_util.flatten_dict(variables["target"]) + old = {"/".join(k): v for k, v in old.items()} + + # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi + split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old + print("Split MLP:", split_mlp_wi) + + new = collections.OrderedDict() + + # Shared embeddings. + new["shared.weight"] = old["token_embedder/embedding"] + + # Encoder. + for i in range(num_layers): + # Block i, layer 0 (Self Attention). + layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm") + k, o, q, v = t5x_attention_lookup(old, i, "encoder", "attention") + new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm + new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T + new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T + new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T + new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T + + # Block i, layer 1 (MLP). + layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm") + wi, wo = t5x_mlp_lookup(old, i, "encoder", split_mlp_wi) + new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm + if split_mlp_wi: + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T + else: + new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T + new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T + + new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ + "encoder/relpos_bias/rel_embedding" + ].T + new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] + + if not is_encoder_only: + # Decoder. 
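+        # Decoder blocks reuse the encoder lookup helpers; each block maps to layer.0 = self-attention,
+        # layer.1 = cross-attention, layer.2 = MLP in the PyTorch state dict.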
+ for i in range(num_decoder_layers): + # Block i, layer 0 (Self Attention). + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention") + new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm + new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T + new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T + new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T + new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T + + # Block i, layer 1 (Cross Attention). + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") + k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention") + new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm + new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T + new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T + + # Block i, layer 2 (MLP). + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm + if split_mlp_wi: + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T + else: + new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T + new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T + + new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"] + new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ + "decoder/relpos_bias/rel_embedding" + ].T + + # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) + if "decoder/logits_dense/kernel" in old: + new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T + + return new + + +def make_state_dict(converted_params, is_encoder_only: bool): + """Prepares a state dict for the PyTorch model.""" + # Make a state dict with torch tensors. + state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + + # Add what is missing. + if "encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + if not is_encoder_only: + if "decoder.embed_tokens.weight" not in state_dict: + state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] + + if "lm_head.weight" not in state_dict: # For old 1.0 models. 
+ print("Using shared word embeddings as lm_head.") + state_dict["lm_head.weight"] = state_dict["shared.weight"] + + return state_dict + + +def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only): + """Replaces the params in model witht the T5X converted params.""" + variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) + converted = convert_t5x_to_pytorch( + variables, + num_layers=config.num_layers, + num_decoder_layers=config.num_decoder_layers, + is_encoder_only=is_encoder_only, + ) + state_dict = make_state_dict(converted, is_encoder_only) + model.load_state_dict(state_dict, strict=True) + + +def convert_t5x_checkpoint_to_pytorch( + t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False +): + """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint.""" + # Initialise PyTorch model + config = T5Config.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + # Non-v1.1 checkpoints could also use T5Model, but this works for all. + # The v1.0 checkpoints will simply have an LM head that is the word embeddings. + if is_encoder_only: + model = T5EncoderModel(config) + else: + model = T5ForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Verify that we can load the checkpoint. + model.from_pretrained(pytorch_dump_path) + print("Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") + # Required parameters + parser.add_argument( + "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 
+ ) + parser.add_argument( + "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False + ) + args = parser.parse_args() + convert_t5x_checkpoint_to_pytorch( + args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only + ) diff --git a/docs/transformers/src/transformers/models/t5/download_from_gcp.sh b/docs/transformers/src/transformers/models/t5/download_from_gcp.sh new file mode 100644 index 0000000000000000000000000000000000000000..fece45c5187cb9cada4fff18f014e3a7cebcd94a --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/download_from_gcp.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Use this script as follows ./download_from_gcp.sh /path/to/folder/to/store/downloads +folder_to_store_downloads=${1} + +# Replace by gcp_path to T5 cloud bucket folder here +# To download the official `t5-small` model of https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints: +gcp_path="gs://t5-data/pretrained_models/small" + +# Number of files the checkpoint is split into +num_of_checks=16 + +# Create dir if not exist +mkdir -p ${folder_to_store_downloads} + +# Copy all meta information files +gsutil cp "${gcp_path}/operative_config.gin" ${folder_to_store_downloads} +gsutil cp "${gcp_path}/checkpoint" ${folder_to_store_downloads} +gsutil cp "${gcp_path}/model.ckpt-1000000.index" ${folder_to_store_downloads} +gsutil cp "${gcp_path}/model.ckpt-1000000.meta" ${folder_to_store_downloads} + +# Copy all model weights +# single digit num checkpoitns +for ((i = 0 ; i < ${num_of_checks} ; i++)); do + gsutil cp "${gcp_path}/model.ckpt-1000000.data-0000${i}-of-000${num_of_checks}" ${folder_to_store_downloads} +done + +# double digit num checkpoints +for ((i = 0 ; i < ${num_of_checks} ; i++)); do + gsutil cp "${gcp_path}/model.ckpt-1000000.data-000${i}-of-000${num_of_checks}" ${folder_to_store_downloads} +done + + +# Having run this script, you should create a suitable config.json, *e.g.* by +# looking at `https://huggingface.co/t5-small`. +# Then you can run `python convert_t5_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path "${folder_to_store_downloads}" --config_file "config.json" --pytorch_dump_path "/path/to/store/pytorch/weights" diff --git a/docs/transformers/src/transformers/models/t5/modeling_flax_t5.py b/docs/transformers/src/transformers/models/t5/modeling_flax_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..be76fe1b7724c0901f1f43661cd37e31a59a345a --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/modeling_flax_t5.py @@ -0,0 +1,1801 @@ +# coding=utf-8 +# Copyright 2021 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
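Before moving on to the Flax modeling code, a quick sanity check on the shard names the two download loops above request. This assumes TensorFlow's usual five-digit zero-padded shard naming and the script's 16-way split:

```python
# Expected shard names for num_of_checks = 16 (assumption: standard TF naming).
num_of_checks = 16
names = [f"model.ckpt-1000000.data-{i:05d}-of-{num_of_checks:05d}" for i in range(num_of_checks)]
print(names[0])   # model.ckpt-1000000.data-00000-of-00016
print(names[-1])  # model.ckpt-1000000.data-00015-of-00016
```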
+"""Flax T5 model.""" + +import copy +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-t5/t5-small" +_CONFIG_FOR_DOC = "T5Config" + +remat = nn_partitioning.remat + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxT5LayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; No bias and no subtraction of mean. 
+ """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +class FlaxT5DenseActDense(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5DenseGatedActDense(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5LayerFF(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +class FlaxT5Attention(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = 
self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
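+        For example, with the defaults (bidirectional=True, num_buckets=32, max_distance=128) each direction gets
+        16 buckets: 8 for the exact offsets 0-7 and 8 logarithmic buckets covering offsets 8 to 127; any larger
+        offset falls into the last bucket.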
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. 
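+            # Concretely: only key positions below cur_index + num_updated_cache_vectors hold real (cached or
+            # freshly written) content; the rest of the pre-allocated cache is still zeros and must remain masked.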
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one 
position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxT5LayerSelfAttention(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxT5Attention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5LayerCrossAttention(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxT5Attention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = 
self.EncDecAttention( + normed_hidden_states, + attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5Block(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.layer = ( + FlaxT5LayerSelfAttention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +class FlaxT5LayerCollection(nn.Module): + config: T5Config + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxT5Block( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + +class 
FlaxT5BlockCollection(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + if self.gradient_checkpointing: + FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) + self.blocks = [ + FlaxT5CheckpointLayer( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + else: + self.blocks = [ + FlaxT5LayerCollection( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + position_bias, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + output_attentions, + deterministic, + init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxT5Stack(nn.Module): + config: T5Config + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxT5BlockCollection( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.final_layer_norm = FlaxT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +T5_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For training, `decoder_input_ids` should be provided. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+ past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxT5PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: T5Config, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxT5EncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + *args, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not 
None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=T5Config) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=T5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
The cache also has to be marked as mutable so that
+        # the FlaxT5Attention module can update it.
+        if past_key_values:
+            inputs["cache"] = past_key_values
+            mutable = ["cache"]
+        else:
+            mutable = False
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+            decoder_module = module._get_decoder_module()
+            return decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                **kwargs,
+            )
+
+        outputs = self.module.apply(
+            inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            mutable=mutable,
+            method=_decoder_forward,
+        )
+
+        # add updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs, past = outputs
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs, past = outputs
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+
+T5_START_DOCSTRING = r"""
+    The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
+    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
+    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
+    text-to-text denoising generative setting.
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+    This model is also a Flax Linen
+    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`T5Config`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified, all the computation will be performed with the given `dtype`.
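+
+            As a minimal sketch (using the same `google-t5/t5-small` checkpoint as the examples below),
+            half-precision computation can be requested at load time:
+
+            ```python
+            >>> import jax.numpy as jnp
+            >>> from transformers import FlaxT5Model
+
+            >>> # the forward pass now runs in float16; the parameters themselves are unaffected
+            >>> model = FlaxT5Model.from_pretrained("google-t5/t5-small", dtype=jnp.float16)
+            ```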
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+
+@add_start_docstrings(
+    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
+    T5_START_DOCSTRING,
+)
+class FlaxT5Module(nn.Module):
+    config: T5Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
+
+    def _get_encoder_module(self):
+        return self.encoder
+
+    def _get_decoder_module(self):
+        return self.decoder
+
+    def setup(self):
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
+        )
+
+        encoder_config = copy.deepcopy(self.config)
+        encoder_config.causal = False
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+
+        decoder_config = copy.deepcopy(self.config)
+        decoder_config.causal = True
+        decoder_config.num_layers = self.config.num_decoder_layers
+        self.decoder = FlaxT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+
+    def __call__(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        deterministic: bool = True,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Encode if needed (training, first prediction pass)
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return FlaxSeq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+class FlaxT5Model(FlaxT5PreTrainedModel):
+    module_class = FlaxT5Module
+
+
+append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
+
+FLAX_T5_MODEL_DOCSTRING = """
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, FlaxT5Model
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+    >>> model = FlaxT5Model.from_pretrained("google-t5/t5-small")
+
+    >>> input_ids = tokenizer(
+    ...     "Studies have been shown that owning a dog is good for you", return_tensors="np"
+    ... 
).input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + + +overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5EncoderModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict: bool = True, + deterministic: bool = True, + ): + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxT5EncoderModel(FlaxT5PreTrainedModel): + module_class = FlaxT5EncoderModule + + @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class 
FlaxT5ForConditionalGenerationModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.model_dim = self.config.d_model + + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = FlaxT5Stack( + encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxT5Stack( + decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = self.shared.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + decoder_outputs[1:] + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): + module_class = FlaxT5ForConditionalGenerationModule + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + 
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=T5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
The cache also has to be marked as mutable so that
+        # the FlaxT5Attention module can update it.
+        if past_key_values:
+            inputs["cache"] = past_key_values
+            mutable = ["cache"]
+        else:
+            mutable = False
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+            decoder_module = module._get_decoder_module()
+            decoder_outputs = decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                **kwargs,
+            )
+
+            sequence_output = decoder_outputs[0]
+
+            if self.config.tie_word_embeddings:
+                # Rescale output before projecting on vocab
+                # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+                sequence_output = sequence_output * (self.config.d_model**-0.5)
+
+            if self.config.tie_word_embeddings:
+                shared_embedding = module.shared.variables["params"]["embedding"]
+                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+            else:
+                lm_logits = module.lm_head(sequence_output)
+
+            return lm_logits, decoder_outputs
+
+        outputs = self.module.apply(
+            inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            mutable=mutable,
+            method=_decoder_forward,
+        )
+
+        if past_key_values is None:
+            lm_logits, decoder_outputs = outputs
+        else:
+            (lm_logits, decoder_outputs), past = outputs
+
+        if return_dict:
+            outputs = FlaxCausalLMOutputWithCrossAttentions(
+                logits=lm_logits,
+                hidden_states=decoder_outputs.hidden_states,
+                attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+            )
+        else:
+            outputs = (lm_logits,) + decoder_outputs[1:]
+
+        # add updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]).sequences + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + + +overwrite_call_docstring( + FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) + + +__all__ = ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/t5/modeling_t5.py b/docs/transformers/src/transformers/models/t5/modeling_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..e8085a8f45651e1537d8561d9d5526c7462924ca --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/modeling_t5.py @@ -0,0 +1,2518 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch T5 model.""" + +import copy +import math +import os +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_t5 import T5Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_CHECKPOINT_FOR_DOC = "google-t5/t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+        )
+        raise
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    tf_weights = {}
+    for name, shape in init_vars:
+        logger.info(f"Loading TF weight {name} with shape {shape}")
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name)
+        tf_weights[name] = array
+
+    for txt_name in names:
+        name = txt_name.split("/")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+        # which are not required for using the pretrained model
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
+            logger.info(f"Skipping {'/'.join(name)}")
+            tf_weights.pop(txt_name, None)
+            continue
+        if "_slot_" in name[-1]:
+            logger.info(f"Skipping {'/'.join(name)}")
+            tf_weights.pop(txt_name, None)
+            continue
+        pointer = model
+        array = tf_weights[txt_name]
+
+        for m_name in name:
+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+                scope_names = re.split(r"_(\d+)", m_name)
+            else:
+                scope_names = [m_name]
+            if scope_names[0] in ["kernel", "scale", "embedding"]:
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "self_attention":
+                pointer = getattr(pointer, "layer")
+                pointer = pointer[0]
+            elif scope_names[0] == "enc_dec_attention":
+                pointer = getattr(pointer, "layer")
+                pointer = pointer[1]
+            elif scope_names[0] == "dense_relu_dense":
+                pointer = getattr(pointer, "layer")
+                pointer = pointer[2]
+            elif scope_names[0] == "rms_norm":
+                if hasattr(pointer, "layer_norm"):
+                    pointer = getattr(pointer, "layer_norm")
+                elif hasattr(pointer, "final_layer_norm"):
+                    pointer = getattr(pointer, "final_layer_norm")
+            elif scope_names[0] == "scale":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "squad":
+                pointer = getattr(pointer, "classifier")
+            elif scope_names[0] == "decoder" and name[1] == "logits":
+                continue
+            elif scope_names[0] == "logits":
+                pointer = getattr(pointer, "lm_head")
+            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
+                pointer = getattr(pointer, f"wi_{scope_names[1]}")
+                continue
+            else:
+                try:
+                    pointer = getattr(pointer, scope_names[0])
+                except AttributeError:
+                    logger.info(f"Skipping {'/'.join(name)}")
+                    continue
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        if scope_names[0] not in ["kernel", "scale", "embedding"]:
+            pointer = getattr(pointer, "weight")
+        if scope_names[0] != "embedding":
+            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
+            array = np.transpose(array)
+        try:
+            if pointer.shape != array.shape:
+                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+        except ValueError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        logger.info(f"Initialize PyTorch weight {name}")
+        pointer.data = torch.from_numpy(array.astype(np.float32))
+        tf_weights.pop(txt_name, None)
+
+    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
+    return model
+
+
+####################################################
+# PyTorch Models are constructed by sub-classing
+# - torch.nn.Module for the layers and
+# - PreTrainedModel for the models (itself a sub-class of nn.Module)
+####################################################
+PARALLELIZE_DOCSTRING = 
r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, *optional*): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - google-t5/t5-small: 6 + - google-t5/t5-base: 12 + - google-t5/t5-large: 24 + - google-t5/t5-3b: 24 + - google-t5/t5-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using google-t5/t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with google-t5/t5-3b: + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm") +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(T5LayerNorm) + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
+        # See https://github.com/huggingface/transformers/issues/20287
+        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
+        if (
+            isinstance(self.wo.weight, torch.Tensor)
+            and hidden_states.dtype != self.wo.weight.dtype
+            and self.wo.weight.dtype != torch.int8
+        ):
+            hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5LayerFF(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(config)
+        else:
+            self.DenseReluDense = T5DenseActDense(config)
+
+        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, hidden_states):
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states
+
+
+class T5Attention(nn.Module):
+    def __init__(
+        self,
+        config: T5Config,
+        has_relative_attention_bias=False,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None and self.is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call if caching is used. Please make sure to provide a "
+                "`layer_idx` when creating this class."
+            )
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        self.pruned_heads = set()
+        self.gradient_checkpointing = False
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+        )
+        # Prune linear layers
+        self.q = prune_linear_layer(self.q, index)
+        self.k = prune_linear_layer(self.k, index)
+        self.v = prune_linear_layer(self.v, index)
+        self.o = prune_linear_layer(self.o, index, dim=1)
+        # Update hyper params
+        self.n_heads = self.n_heads - len(heads)
+        self.inner_dim = self.key_value_proj_dim * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh TensorFlow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. 
The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
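+
+        When this layer runs inside the decoder, `past_key_value` is expected to be an `EncoderDecoderCache`:
+        self-attention key/value states are appended to its `self_attention_cache` at `cache_position`, while
+        cross-attention key/value states are computed once from `key_value_states` and then re-used from its
+        `cross_attention_cache` on later calls (the per-layer `is_updated` flag tracks whether that projection has
+        already happened).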
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * 
layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + 
self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states, past_key_value = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, past_key_value = cross_attention_outputs[:2] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (past_key_value,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: T5Config): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
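+
+    Subclasses share the `_shift_right` helper, which turns `labels` into `decoder_input_ids` by shifting them one
+    position to the right, prepending `config.decoder_start_token_id` (the pad token for T5) and replacing any `-100`
+    loss-masking placeholders with `config.pad_token_id`. A minimal sketch, assuming `model` is any loaded T5
+    checkpoint and the token ids are made up for illustration:
+
+    ```python
+    >>> labels = torch.tensor([[6536, 504, 1]])
+    >>> model._shift_right(labels)  # tensor([[0, 6536, 504]])
+    ```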
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet + _supports_static_cache = True + _supports_cache_class = True + _no_split_modules = ["T5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, T5ForTokenClassification): + if hasattr(module, "classifier"): + module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.data.zero_() + elif isinstance(module, T5ClassificationHead): + module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, T5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid 
scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " + "See T5 docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions 
and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + return_dict, + cache_position, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, next_decoder_cache = layer_outputs[:2] + + # We share the position biases between the layers - the first layer stores them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = 
self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`int`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage + and behavior.
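To make the mask construction in `_prepare_4d_causal_attention_mask_with_cache_position` concrete, here is a toy-size sketch of the non-4D branch (batch size, lengths, and the padding pattern are invented for illustration, and the `triu` step is folded into the cache-position comparison):

```python
import torch

batch_size, seq_len, past = 2, 3, 2
target_len = past + seq_len
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(past, past + seq_len)  # query positions 2, 3, 4

# Fully masked grid, then unmask key positions <= each query's cache position.
causal_mask = torch.full((seq_len, target_len), min_dtype, dtype=dtype)
causal_mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold in a 2D padding mask: positions that are causally visible (0) but
# padded (attention_mask == 0) get re-masked.
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]])
padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
print(causal_mask.shape)  # torch.Size([2, 1, 3, 5])
```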
+ + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder.
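Because `encoder_outputs` can be supplied directly, the encoder can be run once and its output reused, e.g. when generating several candidates for the same source text. A hedged usage sketch (checkpoint and input sentence are just examples):

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
# Run the encoder once ...
encoder_outputs = model.get_encoder()(**inputs)
# ... then pass the cached result in; generate() skips re-encoding.
out = model.generate(attention_mask=inputs.attention_mask, encoder_outputs=encoder_outputs)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```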
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules.
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" + " 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, T5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = T5Model.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. 
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + 
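# model_dim is kept so the decoder output can be rescaled by d_model**-0.5 before the lm_head projection when word embeddings are tied (see the tie_word_embeddings branch in forward below). +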
self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" + " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also" + " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: 
Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size - 1]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, T5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> # training + >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids + >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs into embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = 
decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + 
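# the four reordered key/value tensors of this layer are collected back into a single per-layer tuple +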
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight"] + _keys_to_ignore_on_load_unexpected = [r"decoder"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, T5EncoderModel + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small") + >>> input_ids = tokenizer( + ... 
"Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs + + +@add_start_docstrings( + """ + T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + T5_START_DOCSTRING, +) +class T5ForSequenceClassification(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.transformer = T5Model(config) + self.classification_head = T5ClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates + # decoder_input_ids from input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of <eos> tokens.") + batch_size, _, hidden_size = sequence_output.shape + sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + T5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output) + e.g. for Named-Entity-Recognition (NER) tasks.
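The EOS-based pooling in `T5ForSequenceClassification.forward` above is worth seeing in isolation: the decoder hidden state at each sequence's *last* EOS token becomes the sentence representation. A toy sketch with invented sizes and token ids:

```python
import torch

batch_size, seq_len, hidden_size = 2, 5, 4
eos_token_id = 1
input_ids = torch.tensor([[10, 11, 1, 12, 1], [13, 1, 14, 15, 1]])
sequence_output = torch.randn(batch_size, seq_len, hidden_size)

eos_mask = input_ids.eq(eos_token_id)
# Boolean indexing flattens the selected states, so every example must contain
# the same number of EOS tokens for the reshape below to be well defined.
sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
print(sentence_representation.shape)  # torch.Size([2, 4])
```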
+ """, + T5_START_DOCSTRING, +) +class T5ForTokenClassification(T5PreTrainedModel): + _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = T5EncoderModel(config) + self.dropout = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, outputs[2:-1]) + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers + on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + T5_START_DOCSTRING, +) +class T5ForQuestionAnswering(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.num_labels = config.num_labels + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. 
+ Returns:
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ if start_positions is not None and end_positions is not None:
+ use_cache = False
+
+ # Copied from models.bart.modeling_bart.BartModel.forward
+ # different to other models, T5 automatically creates decoder_input_ids from
+ # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError(
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+ "passed, `input_ids` cannot be `None`. Please pass either "
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+ )
+ decoder_input_ids = self._shift_right(input_ids)
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=None,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, squeeze the extra dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) +
decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +__all__ = [ + "T5EncoderModel", + "T5ForConditionalGeneration", + "T5Model", + "T5PreTrainedModel", + "load_tf_weights_in_t5", + "T5ForQuestionAnswering", + "T5ForSequenceClassification", + "T5ForTokenClassification", +] diff --git a/docs/transformers/src/transformers/models/t5/modeling_tf_t5.py b/docs/transformers/src/transformers/models/t5/modeling_tf_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..84f5b2a636129ab69f58341faf30df135128d13e --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/modeling_tf_t5.py @@ -0,0 +1,1683 @@ +# coding=utf-8 +# Copyright 2020 T5 Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 T5 model.""" + +from __future__ import annotations + +import copy +import itertools +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_slice + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" + + +#################################################### +# TF 2.0 Models are constructed using Keras imperative API by sub-classing +# - keras.layers.Layer for the layers and +# - TFPreTrainedModel for the models (it-self a sub-class of keras.Model) +#################################################### + + +class TFT5LayerNorm(keras.layers.Layer): + def __init__(self, hidden_size, epsilon=1e-6, **kwargs): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
+ """ + super().__init__(**kwargs) + self.variance_epsilon = epsilon + self.hidden_size = hidden_size + + def build(self, input_shape): + """Build shared word embedding layer""" + self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones") + super().build(input_shape) + + def call(self, hidden_states): + variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True) + hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +class TFT5DenseActDense(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + wi_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) + ) + wo_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) + ) + self.wi = keras.layers.Dense( + config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wo = keras.layers.Dense( + config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer + ) # Update init weights as in flax + self.dropout = keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation(config.dense_act_fn) + self.config = config + + def call(self, hidden_states, training=False): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "wi", None) is not None: + with tf.name_scope(self.wi.name): + self.wi.build([None, None, self.config.d_model]) + if getattr(self, "wo", None) is not None: + with tf.name_scope(self.wo.name): + self.wo.build([None, None, self.config.d_ff]) + + +class TFT5DenseGatedActDense(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + wi_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) + ) + wo_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) + ) + self.wi_0 = keras.layers.Dense( + config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wi_1 = keras.layers.Dense( + config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wo = keras.layers.Dense( + config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer + ) # Update init weights as in flax + self.dropout = keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation(config.dense_act_fn) + self.config = config + + def call(self, hidden_states, training=False): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "wi_0", None) is not None: + with tf.name_scope(self.wi_0.name): + self.wi_0.build([None, None, self.config.d_model]) + if getattr(self, "wi_1", None) is not None: + with tf.name_scope(self.wi_1.name): + self.wi_1.build([None, None, 
self.config.d_model]) + if getattr(self, "wo", None) is not None: + with tf.name_scope(self.wo.name): + self.wo.build([None, None, self.config.d_ff]) + + +class TFT5LayerFF(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + if config.is_gated_act: + self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense") + else: + self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense") + + self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def call(self, hidden_states, training=False): + normed_hidden_states = self.layer_norm(hidden_states) + dense_output = self.DenseReluDense(normed_hidden_states, training=training) + hidden_states = hidden_states + self.dropout(dense_output, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build(None) + if getattr(self, "DenseReluDense", None) is not None: + with tf.name_scope(self.DenseReluDense.name): + self.DenseReluDense.build(None) + + +class TFT5Attention(keras.layers.Layer): + NEW_ID = itertools.count() + + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.layer_id = next(TFT5Attention.NEW_ID) + self.is_decoder = config.is_decoder + self.use_cache = config.use_cache + self.has_relative_attention_bias = has_relative_attention_bias + self.output_attentions = config.output_attentions + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + q_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + ) + k_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + v_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + o_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + self.relative_attention_bias_initializer = keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + + self.q = keras.layers.Dense( + self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer + ) # Update init weights as in flax + self.k = keras.layers.Dense( + self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer + ) # Update init weights as in flax + self.v = keras.layers.Dense( + self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer + ) # Update init weights as in flax + self.o = keras.layers.Dense( + self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer + ) # Update init weights as in flax + self.dropout = keras.layers.Dropout(config.dropout_rate) + + self.pruned_heads = set() + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.has_relative_attention_bias: + with 
tf.name_scope("relative_attention_bias"): + self.relative_attention_bias = self.add_weight( + name="embeddings", + shape=[self.relative_attention_num_buckets, self.n_heads], + initializer=self.relative_attention_bias_initializer, # Add initializer + ) + if getattr(self, "q", None) is not None: + with tf.name_scope(self.q.name): + self.q.build([None, None, self.d_model]) + if getattr(self, "k", None) is not None: + with tf.name_scope(self.k.name): + self.k.build([None, None, self.d_model]) + if getattr(self, "v", None) is not None: + with tf.name_scope(self.v.name): + self.v.build([None, None, self.d_model]) + if getattr(self, "o", None) is not None: + with tf.name_scope(self.o.name): + self.o.build([None, None, self.inner_dim]) + + def prune_heads(self, heads): + raise NotImplementedError + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + # n = -relative_position + if bidirectional: + num_buckets //= 2 + relative_buckets += ( + tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets + ) + relative_position = tf.math.abs(relative_position) + else: + relative_position = -tf.math.minimum(relative_position, 0) + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = tf.math.less(relative_position, max_exact) + relative_position_if_large = max_exact + tf.cast( + tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32)) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact), + dtype=relative_position.dtype, + ) + relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = tf.range(query_length)[:, None] + memory_position = tf.range(key_length)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = tf.gather( + self.relative_attention_bias, 
relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = tf.expand_dims( + tf.transpose(values, [2, 0, 1]), axis=0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def call( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + training=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, query_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = shape_list(hidden_states)[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert len(past_key_value) == 2, ( + f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + ) + real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1] + + def shape(hidden_states): + """projection""" + return tf.transpose( + tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3) + ) + + def unshape(hidden_states): + """compute context""" + return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim)) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = tf.concat([past_key_value, hidden_states], axis=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head) + + # get key/value + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # to cope with keras serialization + if self.is_decoder and use_cache: + present_key_value_state = (key_states, value_states) + else: + present_key_value_state = None + + scores = tf.einsum( + "bnqd,bnkd->bnqk", query_states, key_states + ) # (batch_size, n_heads, query_length, key_length) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length)) + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated we want only the last query position bias + if past_key_value is not None: + if not self.has_relative_attention_bias: + position_bias = position_bias[:, :, -seq_length:, :] + else: + # we might have a padded past structure, in which case we want to fetch the position bias slice + # right after 
the most recently filled past index + most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) + position_bias = dynamic_slice( + position_bias, + (0, 0, most_recently_filled_past_index + 1, 0), + (1, self.n_heads, seq_length, real_seq_length), + ) + + if mask is not None: + position_bias = tf.cast(position_bias, dtype=mask.dtype) + position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) + + scores += position_bias + weights = stable_softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length) + weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.n_heads], + message=( + f"Head mask for a single layer should be of size {(self.n_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights + + attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) + + attn_output = self.o(unshape(attn_output)) + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (weights,) + + return outputs + + +class TFT5LayerSelfAttention(keras.layers.Layer): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.SelfAttention = TFT5Attention( + config, + has_relative_attention_bias=has_relative_attention_bias, + name="SelfAttention", + ) + self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def call( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + training=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "SelfAttention", None) is not None: + with tf.name_scope(self.SelfAttention.name): + self.SelfAttention.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build(None) + + +class TFT5LayerCrossAttention(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.EncDecAttention = TFT5Attention( + config, + has_relative_attention_bias=False, + name="EncDecAttention", + ) + self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def call( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + query_length=None, + use_cache=False, + output_attentions=False, + training=False, + ): + normed_hidden_states = 
self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "EncDecAttention", None) is not None: + with tf.name_scope(self.EncDecAttention.name): + self.EncDecAttention.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build(None) + + +class TFT5Block(keras.layers.Layer): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.is_decoder = config.is_decoder + self.layer = [] + self.layer.append( + TFT5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + name="layer_._0", + ) + ) + if self.is_decoder: + self.layer.append( + TFT5LayerCrossAttention( + config, + name="layer_._1", + ) + ) + + self.layer.append(TFT5LayerFF(config, name=f"layer_._{len(self.layer)}")) + + def call( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + encoder_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + training=False, + ): + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (key / value) for cross attention' if expected_num_past_key_values == 4 else ''}. " + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + if self.is_decoder and encoder_hidden_states is not None: + # the actual query length is unknown for cross attention + # if using past key value states. 
Need to inject it here + if present_key_value_state is not None: + query_length = shape_list(present_key_value_state[0])[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=encoder_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = cross_attention_outputs[0] + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, training=training) + outputs = (hidden_states,) + + # Add attentions if we output them + outputs = outputs + (present_key_value_state,) + attention_outputs + return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + for layer_module in self.layer: + if hasattr(layer_module, "name"): + with tf.name_scope(layer_module.name): + layer_module.build(None) + + +#################################################### +# The full model without a specific pretrained or finetuning head is +# provided as a keras.layers.Layer usually called "TFT5MainLayer" +#################################################### +@keras_serializable +class TFT5MainLayer(keras.layers.Layer): + config_class = T5Config + + def __init__(self, config, embed_tokens=None, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.use_cache = config.use_cache + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.config = config + self.num_hidden_layers = config.num_layers + + self.block = [ + TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}") + for i in range(config.num_layers) + ] + self.final_layer_norm = TFT5LayerNorm( + config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm" + ) + self.dropout = keras.layers.Dropout(config.dropout_rate) + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ) -> Tuple: + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + 
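+ # neither input_ids nor inputs_embeds was provided, so fail with a scoped error message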
err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length + ) + + if attention_mask is None: + attention_mask = tf.fill((batch_size, mask_seq_length), 1) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = shape_list(encoder_hidden_states)[1] + encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype) + num_dims_attention_mask = len(shape_list(attention_mask)) + if num_dims_attention_mask == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif num_dims_attention_mask == 2: + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + if past_key_values[0] is not None: + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -1e9 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+ # extended_attention_mask = tf.math.equal(extended_attention_mask,
+ # tf.transpose(extended_attention_mask, perm=(-1, -2)))
+
+ extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
+
+ if self.is_decoder and encoder_attention_mask is not None:
+ # If a 2D or 3D attention mask is provided for the cross-attention,
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+ if num_dims_encoder_attention_mask == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if num_dims_encoder_attention_mask == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+ else:
+ encoder_extended_attention_mask = None
+
+ present_key_value_states = () if use_cache and self.is_decoder else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds, training=training)
+
+ for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ training=training,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer stores them
+ position_bias = layer_outputs[2]
+
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+
+ # append next layer key value states
+ if present_key_value_state is not None and use_cache and self.is_decoder:
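+ # Editor's note (descriptive): one cache entry is accumulated per decoder block,
+ # holding the self-attention key/value pair plus, when cross-attention is
+ # present, its key/value pair.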
present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+ if self.is_decoder:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ outputs = (hidden_states,)
+ # need to check if is decoder here as well for special cases when using keras compile
+ if use_cache and self.is_decoder:
+ outputs = outputs + (present_key_value_states,)
+ if output_hidden_states:
+ outputs = outputs + (all_hidden_states,)
+ if output_attentions:
+ outputs = outputs + (all_attentions,)
+ if self.is_decoder:
+ outputs = outputs + (all_cross_attentions,)
+ return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions)
+
+ if self.is_decoder:
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+ else:
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build(None)
+ if getattr(self, "block", None) is not None:
+ for layer in self.block:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+####################################################
+# TFT5PreTrainedModel is a sub-class of keras.Model
+# which takes care of loading and saving pretrained weights
+# and various common utilities.
+# Here you just need to specify a few (self-explanatory)
+# pointers for your model.
+####################################################
+class TFT5PreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = T5Config
+ base_model_prefix = "transformer"
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
+ _keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"]
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ if hasattr(self, "decoder"):
+ self.decoder.embed_tokens = self.shared
+
+ def _shift_right(self, input_ids):
+ decoder_start_token_id = self.config.decoder_start_token_id
+ pad_token_id = self.config.pad_token_id
+
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the"
+ " pad_token_id. See T5 docs for more information"
+ )
+
+ start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
+ start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
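+ # Editor's note, illustrative example (token ids are made up): with
+ # decoder_start_token_id = pad_token_id = 0, labels [[6536, 504, 1]] become
+ # decoder inputs [[0, 6536, 504]] - the start token is prepended and the last
+ # position dropped before the -100 -> pad_token_id replacement below.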
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids = tf.where(
+ shifted_input_ids == -100,
+ tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
+ shifted_input_ids,
+ )
+
+ # Verify that `labels` has only positive values and -100
+ assert_gte0 = tf.debugging.assert_greater_equal(
+ shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
+ )
+
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
+ with tf.control_dependencies([assert_gte0]):
+ shifted_input_ids = tf.identity(shifted_input_ids)
+
+ return shifted_input_ids
+
+
+T5_START_DOCSTRING = r"""
+
+ The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
+ text-to-text denoising generative setting.
+
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+ behavior.
+
+ <Tip>
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+ </Tip>
+
+ Parameters:
+ config ([`T5Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+T5_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
T5 is a model with relative position embeddings so you + should be able to pad the inputs on the right or the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `inputs` for pretraining take a look at [T5 Training](./t5#training). + decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for sequence to sequence training. T5 uses the `pad_token_id` as the starting token for + `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` + have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(tf.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. 
If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + inputs (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on the right or the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + To know more on how to prepare `inputs` for pre-training take a look at [T5 Training](./t5#training). + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+_HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
+num_heads))`.
+"""
+
+
+@add_start_docstrings(
+    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
+    T5_START_DOCSTRING,
+)
+class TFT5Model(TFT5PreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.shared = keras.layers.Embedding(
+            input_dim=config.vocab_size,
+            output_dim=config.d_model,
+            embeddings_initializer=keras.initializers.TruncatedNormal(self.config.initializer_factor),
+            name="shared",
+        )
+        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+        self.shared.load_weight_prefix = "shared"
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
+        self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder")
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder")
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+        encoder_outputs: np.ndarray | tf.Tensor | None = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[Tuple, TFSeq2SeqModelOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFT5Model
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+        >>> model = TFT5Model.from_pretrained("google-t5/t5-small")
+
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="tf"
+        ... ).input_ids  # Batch size 1
+        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids  # Batch size 1
+
+        >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
+        >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + past_key_values=None, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + past = decoder_outputs[1] if use_cache else None + + if not return_dict: + if past_key_values is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
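+        # Note: a tf.name_scope string that ends with "/" is treated as absolute by TensorFlow, so the
+        # embedding weights below end up under the top-level "shared/..." path instead of under this model's
+        # own scope, which is what lets encoder, decoder and checkpoints agree on the tied weights.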
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model_dim = config.d_model + self.shared = keras.layers.Embedding( + config.vocab_size, + config.d_model, + name="shared", + embeddings_initializer=get_initializer(self.config.initializer_factor), + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "shared" + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder") + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder") + + if not config.tie_word_embeddings: + lm_head_initializer = keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor) + self.lm_head = keras.layers.Dense( + config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer + ) # Update init weights as in flax + self.config = config + + def get_output_embeddings(self): + if self.config.tie_word_embeddings: + return self.get_input_embeddings() + else: + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + return tf.transpose(self.lm_head.kernel) + + def set_output_embeddings(self, value): + if self.config.tie_word_embeddings: + self.set_input_embeddings(value) + else: + lm_head_initializer = keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor) + self.lm_head = keras.layers.Dense( + shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer + ) # Update init weights as in flax + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + transposed_value = tf.transpose(value) + self.lm_head.kernel = transposed_value + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + decoder_head_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = 
None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[Tuple, TFSeq2SeqLMOutput]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
+            config.vocab_size - 1]`.
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFT5ForConditionalGeneration
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+        >>> model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
+
+        >>> # training
+        >>> inputs = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="tf").input_ids
+        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="tf").input_ids
+        >>> outputs = model(inputs, labels=labels)
+        >>> loss = outputs.loss
+        >>> logits = outputs.logits
+
+        >>> # inference
+        >>> inputs = tokenizer(
+        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="tf"
+        ... ).input_ids  # Batch size 1
+        >>> outputs = model.generate(inputs)
+        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+        >>> # studies have shown that owning a dog is good for you
+        ```"""
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
+            decoder_head_mask = head_mask
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            head_mask=decoder_head_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        # T5v1.1 does not tie output word embeddings and thus does not require downscaling
+        if self.config.tie_word_embeddings:
+            sequence_output = sequence_output * (self.model_dim**-0.5)
+            logits = tf.matmul(sequence_output, self.shared.weights, transpose_b=True)
+        else:
+            logits = self.lm_head(sequence_output)
+
+        logits = tf.cast(logits, tf.float32)
+
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+        past = decoder_outputs[1] if use_cache else None
+        if not return_dict:
+            if past_key_values is not None:
+                decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
+            output = (logits,) + decoder_outputs[1:] + encoder_outputs
+            return ((loss,) + output) if loss is not None else output
+
+        # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
+        elif isinstance(encoder_outputs, tuple):
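+            # A plain encoder tuple is ordered (last_hidden_state, hidden_states?, attentions?), where the
+            # optional entries only exist when the matching output_* flag was set, so the index bookkeeping
+            # below walks the tuple before re-wrapping it as a TFBaseModelOutput.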
last_hidden_state = encoder_outputs[0] + hidden_states = None + attentions = None + idx = 0 + if output_hidden_states: + idx += 1 + hidden_states = encoder_outputs[idx] + if output_attentions: + idx += 1 + attentions = encoder_outputs[idx] + + encoder_outputs = TFBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + return TFSeq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return self._shift_right(labels) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. 
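+        # Note: `lm_head` is only created in __init__ when `config.tie_word_embeddings` is False, hence the
+        # getattr guard below; with tied embeddings the logits are computed directly against `shared` in call().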
+        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
+            self.shared.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "decoder", None) is not None:
+            with tf.name_scope(self.decoder.name):
+                self.decoder.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build([None, None, self.config.d_model])
+
+
+@add_start_docstrings(
+    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+    T5_START_DOCSTRING,
+)
+class TFT5EncoderModel(TFT5PreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.shared = keras.layers.Embedding(
+            config.vocab_size,
+            config.d_model,
+            name="shared",
+            embeddings_initializer=get_initializer(self.config.initializer_factor),
+        )
+        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+        self.shared.load_weight_prefix = "shared"
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
+        self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder")
+
+    def get_encoder(self):
+        return self.encoder
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[Tuple, TFBaseModelOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFT5EncoderModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+        >>> model = TFT5EncoderModel.from_pretrained("google-t5/t5-small")
+
+        >>> input_ids = tokenizer(
+        ...     "Studies have been shown that owning a dog is good for you", return_tensors="tf"
+        ... ).input_ids  # Batch size 1
+        >>> outputs = model(input_ids)
+        ```"""
+
+        encoder_outputs = self.encoder(
+            input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            past_key_values=None,
+            use_cache=False,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        if not return_dict:
+            return encoder_outputs
+
+        return TFBaseModelOutput(
+            last_hidden_state=encoder_outputs.last_hidden_state,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        # The shared/tied weights expect to be in the model base namespace
+        # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
+        # the current one.
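+        # Same root-namespace trick as in the models above: keeping `shared` at the top level should leave
+        # encoder-only checkpoints name-compatible with full encoder-decoder T5 checkpoints.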
+        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
+            self.shared.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+
+
+__all__ = ["TFT5EncoderModel", "TFT5ForConditionalGeneration", "TFT5Model", "TFT5PreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/t5/tokenization_t5.py b/docs/transformers/src/transformers/models/t5/tokenization_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..01447232b890e4dfa126564dbca7de3c44902bc1
--- /dev/null
+++ b/docs/transformers/src/transformers/models/t5/tokenization_t5.py
@@ -0,0 +1,450 @@
+# coding=utf-8
+# Copyright 2018 T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model T5."""
+
+import os
+import re
+import warnings
+from shutil import copyfile
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...convert_slow_tokenizer import import_protobuf
+from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils_base import AddedToken
+
+
+if TYPE_CHECKING:
+    from ...tokenization_utils_base import TextInput
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+
+@requires(backends=("sentencepiece",))
+class T5Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a T5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        extra_ids (`int`, *optional*, defaults to 100):
+            Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are
+            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be
+            retrieved by calling the get_sentinel_tokens method, and their token ids by calling the
+            get_sentinel_token_ids method.
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        legacy (`bool`, *optional*):
+            Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622
+            and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
+            example:
+
+            - `legacy=True`:
+            ```python
+            >>> from transformers import T5Tokenizer
+
+            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
+            >>> tokenizer.encode("Hello <extra_id_0>.")
+            [8774, 32099, 3, 5, 1]
+            ```
+            - `legacy=False`:
+            ```python
+            >>> from transformers import T5Tokenizer
+
+            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
+            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
+            [8774, 32099, 5, 1]
+            ```
+            Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
+        add_prefix_space (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated just as
+            any other word.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        extra_ids=100,
+        additional_special_tokens=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        legacy=None,
+        add_prefix_space=True,
+        **kwargs,
+    ) -> None:
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.vocab_file = vocab_file
+        self._extra_ids = extra_ids
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        if additional_special_tokens is not None:
+            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
+            if len(extra_tokens) < 1:
+                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
+            elif extra_ids > 0 and extra_ids != len(extra_tokens):
+                raise ValueError(
+                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+                    " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+                    " tokens"
+                )
+        else:
+            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+            additional_special_tokens = extra_tokens
+
+        # for legacy purpose, we keep this. Will be removed and tests updated
+        # (when `added_tokens_decoder` is not passed as kwargs).
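+        # Illustrative note: with a 32_000-piece sentencepiece model and extra_ids=100, <extra_id_0> gets id
+        # 31_999 + 100 - 0 = 32_099 and <extra_id_99> gets id 32_000; the sentinels are indexed from the end
+        # of the vocabulary downwards, matching the 32099 seen for <extra_id_0> in the docstring examples above.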
+        self._added_tokens_decoder = {}
+        for i in range(len(extra_tokens)):
+            self._added_tokens_decoder[len(self.sp_model) - 1 + extra_ids - i] = AddedToken(
+                f"<extra_id_{i}>", single_word=False, lstrip=True, rstrip=True, special=True, normalized=False
+            )
+
+        if legacy is None:
+            logger.warning_once(
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
+                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
+                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
+                " means, and thoroughly read the reason why this was added as explained in"
+                " https://github.com/huggingface/transformers/pull/24565"
+            )
+            legacy = True
+
+        self.legacy = legacy
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
+        self.vocab_file = vocab_file
+        self._extra_ids = extra_ids
+        self.add_prefix_space = add_prefix_space
+
+        super().__init__(
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            extra_ids=extra_ids,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            legacy=legacy,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
+    def get_spm_processor(self, from_slow=False):
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        if self.legacy or from_slow:  # no dependency on protobuf
+            tokenizer.Load(self.vocab_file)
+            return tokenizer
+
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
+            model = model_pb2.ModelProto.FromString(sp_model)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
+
+    @staticmethod
+    def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+        if pretrained_model_name_or_path in T5Tokenizer.max_model_input_sizes:
+            deprecated_max_model_length = T5Tokenizer.max_model_input_sizes[pretrained_model_name_or_path]
+            if init_max_model_length is not None and init_max_model_length != max_model_length:
+                return init_max_model_length
+            elif init_max_model_length is None:
+                warnings.warn(
+                    "This tokenizer was incorrectly instantiated with a model max length of"
+                    f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+                    " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+                    " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+                    f" {pretrained_model_name_or_path} automatically truncating your input to"
+                    f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+                    f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+                    " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+                    " instantiate this tokenizer with `model_max_length` set to your preferred value.",
+                    FutureWarning,
+                )
+
+        return max_model_length
+
+    @property
+    def vocab_size(self):
+        return self.sp_model.get_piece_size()
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
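+        # merge in the tokens added on top of the sentencepiece model (e.g. the <extra_id_*> sentinels)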
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        # normal case: some special tokens
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: bool(re.search(r"<extra_id_\d+>", x)), self.additional_special_tokens))
+        )
+
+    def get_sentinel_token_ids(self):
+        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
+
+    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
+        """Do not add eos again if user already added it."""
+        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+            warnings.warn(
+                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+                " eos tokens being added."
+            )
+            return token_ids
+        else:
+            return token_ids + [self.eos_token_id]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
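+
+        Example (an illustrative sketch; the exact ids depend on the loaded vocabulary, though `1` is the usual
+        `</s>` id for T5 checkpoints):
+
+        ```python
+        >>> tokenizer.build_inputs_with_special_tokens([8774, 32099])
+        [8774, 32099, 1]
+        ```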
+ """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["T5Tokenizer"] diff --git a/docs/transformers/src/transformers/models/t5/tokenization_t5_fast.py b/docs/transformers/src/transformers/models/t5/tokenization_t5_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb652728bfac9bed797679df9adedfdacf0cefe --- /dev/null +++ b/docs/transformers/src/transformers/models/t5/tokenization_t5_fast.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2018 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization class for model T5.""" + +import os +import re +import warnings +from shutil import copyfile +from typing import List, Optional, Tuple + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_t5 import T5Tokenizer +else: + T5Tokenizer = None + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +# TODO(PVP) - this should be removed in Transformers v5 + + +class T5TokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" T5 tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + extra_ids (`int`, *optional*, defaults to 100): + Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as + "" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by + calling get_sentinel_tokens method and token ids can be by calling get_sentinel_token_ids method + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + add_prefix_space (`bool`, *optional*): + Whether or not the tokenizer should automatically add a prefix space + from_slow (`book`, *optional*, defaults to `False`): + Whether or not the tokenizer should be converted from a slow one. If `add_prefix_space` is set, this will be set to `True`. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = T5Tokenizer + + prefix_tokens: List[int] = [] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + eos_token="", + unk_token="", + pad_token="", + extra_ids=100, + additional_special_tokens=None, + add_prefix_space=None, + **kwargs, + ): + # Add extra_ids to the special token list + if additional_special_tokens is not None: + extra_tokens = [x for x in additional_special_tokens if "" for i in range(extra_ids)] + elif extra_ids > 0 and extra_ids != len(extra_tokens): + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" + " provided to T5Tokenizer. 
+                    " In this case the additional_special_tokens must include the extra_ids tokens"
+                )
+        else:
+            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+            additional_special_tokens = extra_tokens
+
+        if add_prefix_space is not None:
+            logger.warning_once(
+                "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
+            )
+            kwargs["from_slow"] = True
+
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            extra_ids=extra_ids,
+            additional_special_tokens=additional_special_tokens,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+        self._extra_ids = extra_ids
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    @staticmethod
+    def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+        if pretrained_model_name_or_path in T5TokenizerFast.max_model_input_sizes:
+            deprecated_max_model_length = T5TokenizerFast.max_model_input_sizes[pretrained_model_name_or_path]
+            if init_max_model_length is not None and init_max_model_length != max_model_length:
+                return init_max_model_length
+            elif init_max_model_length is None:
+                warnings.warn(
+                    "This tokenizer was incorrectly instantiated with a model max length of"
+                    f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+                    " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+                    " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+                    f" {pretrained_model_name_or_path} automatically truncating your input to"
+                    f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+                    f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+                    " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+                    " instantiate this tokenizer with `model_max_length` set to your preferred value.",
+                    FutureWarning,
+                )
+
+        return max_model_length
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+            logger.info(f"Copy vocab file to {out_vocab_file}")
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        token_ids_0 = token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return self.prefix_tokens + token_ids_0
+        else:
+            token_ids_1 = token_ids_1 + [self.eos_token_id]
+            return self.prefix_tokens + token_ids_0 + token_ids_1
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: bool(re.search(r"<extra_id_\d+>", x)), self.additional_special_tokens))
+        )
+
+    def get_sentinel_token_ids(self):
+        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
+
+
+__all__ = ["T5TokenizerFast"]
diff --git a/docs/transformers/src/transformers/models/table_transformer/__init__.py b/docs/transformers/src/transformers/models/table_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebd1da6fdaeea49c47590e03c86474733265baa
--- /dev/null
+++ b/docs/transformers/src/transformers/models/table_transformer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_table_transformer import *
+    from .modeling_table_transformer import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/table_transformer/configuration_table_transformer.py b/docs/transformers/src/transformers/models/table_transformer/configuration_table_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..458be0eea3106b00f2821a6e802fcfa4684f79f5
--- /dev/null
+++ b/docs/transformers/src/transformers/models/table_transformer/configuration_table_transformer.py
@@ -0,0 +1,279 @@
+# coding=utf-8
+# Copyright The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Table Transformer model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class TableTransformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`TableTransformerModel`]. It is used to
+    instantiate a Table Transformer model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Table Transformer
+    [microsoft/table-transformer-detection](https://huggingface.co/microsoft/table-transformer-detection) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        use_timm_backbone (`bool`, *optional*, defaults to `True`):
+            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
+            API.
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
+            case it will default to `ResNetConfig()`.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        num_queries (`int`, *optional*, defaults to 100):
+            Number of object queries, i.e. detection slots. This is the maximal number of objects
+            [`TableTransformerModel`] can detect in a single image. For COCO, we recommend 100 queries.
+        d_model (`int`, *optional*, defaults to 256):
+            Dimension of the layers.
+        encoder_layers (`int`, *optional*, defaults to 6):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        init_xavier_std (`float`, *optional*, defaults to 1):
+            The scaling factor used for the Xavier initialization gain in the HM Attention map module.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        auxiliary_loss (`bool`, *optional*, defaults to `False`):
+            Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+        position_embedding_type (`str`, *optional*, defaults to `"sine"`):
+            Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
+            Whether to use pretrained weights for the backbone.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint,
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        dilation (`bool`, *optional*, defaults to `False`):
+            Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
+            `use_timm_backbone` = `True`.
+        class_cost (`float`, *optional*, defaults to 1):
+            Relative weight of the classification error in the Hungarian matching cost.
+        bbox_cost (`float`, *optional*, defaults to 5):
+            Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
+        giou_cost (`float`, *optional*, defaults to 2):
+            Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
+        mask_loss_coefficient (`float`, *optional*, defaults to 1):
+            Relative weight of the Focal loss in the panoptic segmentation loss.
+        dice_loss_coefficient (`float`, *optional*, defaults to 1):
+            Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
+        bbox_loss_coefficient (`float`, *optional*, defaults to 5):
+            Relative weight of the L1 bounding box loss in the object detection loss.
+        giou_loss_coefficient (`float`, *optional*, defaults to 2):
+            Relative weight of the generalized IoU loss in the object detection loss.
+        eos_coefficient (`float`, *optional*, defaults to 0.1):
+            Relative classification weight of the 'no-object' class in the object detection loss.
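+
+    A hedged sketch of the two backbone paths described above (argument values are illustrative, not defaults):
+
+    ```python
+    >>> # timm backbone with dilation (DC5); __init__ fills in output_stride/out_indices automatically
+    >>> config = TableTransformerConfig(use_timm_backbone=True, dilation=True)
+
+    >>> # transformers AutoBackbone instead of timm, configured explicitly
+    >>> config = TableTransformerConfig(
+    ...     use_timm_backbone=False,
+    ...     backbone_config={"model_type": "resnet", "out_features": ["stage4"]},
+    ... )
+    ```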
+ + Examples: + + ```python + >>> from transformers import TableTransformerModel, TableTransformerConfig + + >>> # Initializing a Table Transformer microsoft/table-transformer-detection style configuration + >>> configuration = TableTransformerConfig() + + >>> # Initializing a model from the microsoft/table-transformer-detection style configuration + >>> model = TableTransformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "table-transformer" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + # Copied from transformers.models.detr.configuration_detr.DetrConfig.__init__ + def __init__( + self, + use_timm_backbone=True, + backbone_config=None, + num_channels=3, + num_queries=100, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + auxiliary_loss=False, + position_embedding_type="sine", + backbone="resnet50", + use_pretrained_backbone=True, + backbone_kwargs=None, + dilation=False, + class_cost=1, + bbox_cost=5, + giou_cost=2, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + **kwargs, + ): + # We default to values which were previously hard-coded in the model. This enables configurability of the config + # while keeping the default behavior the same. + if use_timm_backbone and backbone_kwargs is None: + backbone_kwargs = {} + if dilation: + backbone_kwargs["output_stride"] = 16 + backbone_kwargs["out_indices"] = [1, 2, 3, 4] + backbone_kwargs["in_chans"] = num_channels + # Backwards compatibility + elif not use_timm_backbone and backbone in (None, "resnet50"): + if backbone_config is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + backbone = None + # set timm attributes to None + dilation = None + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.use_timm_backbone = use_timm_backbone + self.backbone_config = backbone_config + self.num_channels = num_channels + self.num_queries = num_queries + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.init_xavier_std = init_xavier_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.num_hidden_layers = encoder_layers + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.backbone_kwargs = backbone_kwargs + self.dilation = dilation + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.eos_coefficient = eos_coefficient + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + +# Copied from transformers.models.detr.configuration_detr.DetrOnnxConfig +class TableTransformerOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("pixel_mask", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + @property + def default_onnx_opset(self) -> int: + return 12 + + +__all__ = ["TableTransformerConfig", "TableTransformerOnnxConfig"] diff --git a/docs/transformers/src/transformers/models/table_transformer/convert_table_transformer_to_hf.py b/docs/transformers/src/transformers/models/table_transformer/convert_table_transformer_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..487cdc481992867d1318ff4db706c872657480d6 --- /dev/null +++ b/docs/transformers/src/transformers/models/table_transformer/convert_table_transformer_to_hf.py @@ -0,0 +1,317 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Table Transformer checkpoints with timm-backbone. + +URL: https://github.com/microsoft/table-transformer +""" + +import argparse +from collections import OrderedDict +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision.transforms import functional as F + +from transformers import DetrImageProcessor, TableTransformerConfig, TableTransformerForObjectDetection +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# here we list all keys to be renamed (original name on the left, our name on the right) +rename_keys = [] +for i in range(6): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + (f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight") + ) + rename_keys.append( + (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias") + ) + rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight")) + rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias")) + rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight")) + rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias")) + rename_keys.append( + (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias")) + rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight")) + rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias")) + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append( + (f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight") + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias") + ) + rename_keys.append( + ( + f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight", + f"decoder.layers.{i}.encoder_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias", + f"decoder.layers.{i}.encoder_attn.out_proj.bias", + ) + ) + rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight")) + rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias")) + rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight")) + 
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias")) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias")) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight") + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias") + ) + rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight")) + rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias")) + +# convolutional projection + query embeddings + layernorm of encoder + layernorm of decoder + class and bounding box heads +rename_keys.extend( + [ + ("input_proj.weight", "input_projection.weight"), + ("input_proj.bias", "input_projection.bias"), + ("query_embed.weight", "query_position_embeddings.weight"), + ("transformer.encoder.norm.weight", "encoder.layernorm.weight"), + ("transformer.encoder.norm.bias", "encoder.layernorm.bias"), + ("transformer.decoder.norm.weight", "decoder.layernorm.weight"), + ("transformer.decoder.norm.bias", "decoder.layernorm.bias"), + ("class_embed.weight", "class_labels_classifier.weight"), + ("class_embed.bias", "class_labels_classifier.bias"), + ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"), + ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"), + ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"), + ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"), + ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"), + ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"), + ] +) + + +def rename_key(state_dict, old, new): + val = state_dict.pop(old) + state_dict[new] = val + + +def rename_backbone_keys(state_dict): + new_state_dict = OrderedDict() + for key, value in state_dict.items(): + if "backbone.0.body" in key: + new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model") + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + + return new_state_dict + + +def read_in_q_k_v(state_dict): + prefix = "" + + # first: transformer encoder + for i in range(6): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(6): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = 
state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + # read in weights + bias of input projection layer of cross-attention + in_proj_weight_cross_attn = state_dict.pop( + f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight" + ) + in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias") + # next, add query, keys and values (in that order) of cross-attention to the state dict + state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :] + state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256] + state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :] + state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512] + state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :] + state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:] + + +def resize(image, checkpoint_url): + width, height = image.size + current_max_size = max(width, height) + target_max_size = 800 if "detection" in checkpoint_url else 1000 + scale = target_max_size / current_max_size + resized_image = image.resize((int(round(scale * width)), int(round(scale * height)))) + + return resized_image + + +def normalize(image): + image = F.to_tensor(image) + image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return image + + +@torch.no_grad() +def convert_table_transformer_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub): + """ + Copy/paste/tweak model's weights to our DETR structure. + """ + + logger.info("Converting model...") + + # load original state dict + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + # rename keys + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + state_dict = rename_backbone_keys(state_dict) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them + prefix = "model." 
+ for key in state_dict.copy().keys(): + if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"): + val = state_dict.pop(key) + state_dict[prefix + key] = val + # create HuggingFace model and load state dict + config = TableTransformerConfig( + backbone="resnet18", + mask_loss_coefficient=1, + dice_loss_coefficient=1, + ce_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.4, + class_cost=1, + bbox_cost=5, + giou_cost=2, + ) + + if "detection" in checkpoint_url: + config.num_queries = 15 + config.num_labels = 2 + id2label = {0: "table", 1: "table rotated"} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + else: + config.num_queries = 125 + config.num_labels = 6 + id2label = { + 0: "table", + 1: "table column", + 2: "table row", + 3: "table column header", + 4: "table projected row header", + 5: "table spanning cell", + } + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + image_processor = DetrImageProcessor( + format="coco_detection", max_size=800 if "detection" in checkpoint_url else 1000 + ) + model = TableTransformerForObjectDetection(config) + model.load_state_dict(state_dict) + model.eval() + + # verify our conversion + filename = "example_pdf.png" if "detection" in checkpoint_url else "example_table.png" + file_path = hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename=filename) + image = Image.open(file_path).convert("RGB") + pixel_values = normalize(resize(image, checkpoint_url)).unsqueeze(0) + + outputs = model(pixel_values) + + if "detection" in checkpoint_url: + expected_shape = (1, 15, 3) + expected_logits = torch.tensor( + [[-6.7897, -16.9985, 6.7937], [-8.0186, -22.2192, 6.9677], [-7.3117, -21.0708, 7.4055]] + ) + expected_boxes = torch.tensor([[0.4867, 0.1767, 0.6732], [0.6718, 0.4479, 0.3830], [0.4716, 0.1760, 0.6364]]) + + else: + expected_shape = (1, 125, 7) + expected_logits = torch.tensor( + [[-18.1430, -8.3214, 4.8274], [-18.4685, -7.1361, -4.2667], [-26.3693, -9.3429, -4.9962]] + ) + expected_boxes = torch.tensor([[0.4983, 0.5595, 0.9440], [0.4916, 0.6315, 0.5954], [0.6108, 0.8637, 0.1135]]) + + assert outputs.logits.shape == expected_shape + assert torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4) + assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + # Save model and image processor + logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + # Push model to HF hub + logger.info("Pushing model to the hub...") + model_name = ( + "microsoft/table-transformer-detection" + if "detection" in checkpoint_url + else "microsoft/table-transformer-structure-recognition" + ) + model.push_to_hub(model_name) + image_processor.push_to_hub(model_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_url", + default="https://pubtables1m.blob.core.windows.net/model/pubtables1m_detection_detr_r18.pth", + type=str, + choices=[ + "https://pubtables1m.blob.core.windows.net/model/pubtables1m_detection_detr_r18.pth", + "https://pubtables1m.blob.core.windows.net/model/pubtables1m_structure_detr_r18.pth", + ], + help="URL of the 
Table Transformer checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + args = parser.parse_args() + convert_table_transformer_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/table_transformer/convert_table_transformer_to_hf_no_timm.py b/docs/transformers/src/transformers/models/table_transformer/convert_table_transformer_to_hf_no_timm.py new file mode 100644 index 0000000000000000000000000000000000000000..1073d48877431006c06ce16a18a19d905d2aa30a --- /dev/null +++ b/docs/transformers/src/transformers/models/table_transformer/convert_table_transformer_to_hf_no_timm.py @@ -0,0 +1,434 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Table Transformer checkpoints with native (Transformers) backbone. + +URL: https://github.com/microsoft/table-transformer +""" + +import argparse +from pathlib import Path + +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision.transforms import functional as F + +from transformers import DetrImageProcessor, ResNetConfig, TableTransformerConfig, TableTransformerForObjectDetection +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def create_rename_keys(config): + # here we list all keys to be renamed (original name on the left, our name on the right) + rename_keys = [] + + # stem + # fmt: off + rename_keys.append(("backbone.0.body.conv1.weight", "backbone.conv_encoder.model.embedder.embedder.convolution.weight")) + rename_keys.append(("backbone.0.body.bn1.weight", "backbone.conv_encoder.model.embedder.embedder.normalization.weight")) + rename_keys.append(("backbone.0.body.bn1.bias", "backbone.conv_encoder.model.embedder.embedder.normalization.bias")) + rename_keys.append(("backbone.0.body.bn1.running_mean", "backbone.conv_encoder.model.embedder.embedder.normalization.running_mean")) + rename_keys.append(("backbone.0.body.bn1.running_var", "backbone.conv_encoder.model.embedder.embedder.normalization.running_var")) + # stages + for stage_idx in range(len(config.backbone_config.depths)): + for layer_idx in range(config.backbone_config.depths[stage_idx]): + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv1.weight", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.convolution.weight", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn1.weight", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.weight", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn1.bias", 
+ f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.bias", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn1.running_mean", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.running_mean", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn1.running_var", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.running_var", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv2.weight", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.convolution.weight", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn2.weight", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.weight", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn2.bias", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.bias", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn2.running_mean", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.running_mean", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn2.running_var", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.running_var", + ) + ) + # all ResNet stages except the first one have a downsample as first layer + if stage_idx != 0 and layer_idx == 0: + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.0.weight", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.weight", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.bias", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias", + ) + ) + rename_keys.append( + ( + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_mean", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean", + ) + ) + rename_keys.append( + ( + # "backbone.conv_encoder.model.encoder.stages.3.layers.0.shortcut.normalization.running_var" + f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_var", + f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var", + ) + ) + # fmt: on + + for i in range(config.encoder_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + ( + f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", + f"encoder.layers.{i}.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias") + ) + rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight")) + rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", 
f"encoder.layers.{i}.fc1.bias")) + rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight")) + rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias")) + rename_keys.append( + (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append( + (f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias") + ) + rename_keys.append( + (f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight") + ) + rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias")) + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append( + ( + f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", + f"decoder.layers.{i}.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias") + ) + rename_keys.append( + ( + f"transformer.decoder.layers.{i}.multihead_attn.out_proj.weight", + f"decoder.layers.{i}.encoder_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"transformer.decoder.layers.{i}.multihead_attn.out_proj.bias", + f"decoder.layers.{i}.encoder_attn.out_proj.bias", + ) + ) + rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight")) + rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias")) + rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight")) + rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias")) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias") + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight") + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias") + ) + rename_keys.append( + (f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight") + ) + rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias")) + + # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads + rename_keys.extend( + [ + ("input_proj.weight", "input_projection.weight"), + ("input_proj.bias", "input_projection.bias"), + ("query_embed.weight", "query_position_embeddings.weight"), + ("transformer.decoder.norm.weight", "decoder.layernorm.weight"), + ("transformer.decoder.norm.bias", "decoder.layernorm.bias"), + ("class_embed.weight", "class_labels_classifier.weight"), + ("class_embed.bias", "class_labels_classifier.bias"), + ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"), + ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"), + ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"), + ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"), + ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"), + ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"), + 
("transformer.encoder.norm.weight", "encoder.layernorm.weight"), + ("transformer.encoder.norm.bias", "encoder.layernorm.bias"), + ] + ) + + return rename_keys + + +def rename_key(state_dict, old, new): + val = state_dict.pop(old) + state_dict[new] = val + + +def read_in_q_k_v(state_dict, is_panoptic=False): + prefix = "" + if is_panoptic: + prefix = "detr." + + # first: transformer encoder + for i in range(6): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(6): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + # read in weights + bias of input projection layer of cross-attention + in_proj_weight_cross_attn = state_dict.pop( + f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_weight" + ) + in_proj_bias_cross_attn = state_dict.pop(f"{prefix}transformer.decoder.layers.{i}.multihead_attn.in_proj_bias") + # next, add query, keys and values (in that order) of cross-attention to the state dict + state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :] + state_dict[f"decoder.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256] + state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :] + state_dict[f"decoder.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512] + state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :] + state_dict[f"decoder.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:] + + +def resize(image, checkpoint_url): + width, height = image.size + current_max_size = max(width, height) + target_max_size = 800 if "detection" in checkpoint_url else 1000 + scale = target_max_size / current_max_size + resized_image = image.resize((int(round(scale * width)), 
int(round(scale * height)))) + + return resized_image + + +def normalize(image): + image = F.to_tensor(image) + image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return image + + +@torch.no_grad() +def convert_table_transformer_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub): + """ + Copy/paste/tweak model's weights to our DETR structure. + """ + + logger.info("Converting model...") + + # create HuggingFace model and load state dict + backbone_config = ResNetConfig.from_pretrained( + "microsoft/resnet-18", out_features=["stage1", "stage2", "stage3", "stage4"] + ) + + config = TableTransformerConfig( + backbone_config=backbone_config, + use_timm_backbone=False, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + ce_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.4, + class_cost=1, + bbox_cost=5, + giou_cost=2, + ) + + # load original state dict + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + + # rename keys + for src, dest in create_rename_keys(config): + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them + prefix = "model." + for key in state_dict.copy().keys(): + if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"): + val = state_dict.pop(key) + state_dict[prefix + key] = val + + if "detection" in checkpoint_url: + config.num_queries = 15 + config.num_labels = 2 + id2label = {0: "table", 1: "table rotated"} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + else: + config.num_queries = 125 + config.num_labels = 6 + id2label = { + 0: "table", + 1: "table column", + 2: "table row", + 3: "table column header", + 4: "table projected row header", + 5: "table spanning cell", + } + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + image_processor = DetrImageProcessor(format="coco_detection", size={"longest_edge": 800}) + model = TableTransformerForObjectDetection(config) + model.load_state_dict(state_dict) + model.eval() + + # verify our conversion + filename = "example_pdf.png" if "detection" in checkpoint_url else "example_table.png" + file_path = hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename=filename) + image = Image.open(file_path).convert("RGB") + pixel_values = normalize(resize(image, checkpoint_url)).unsqueeze(0) + + outputs = model(pixel_values) + + if "detection" in checkpoint_url: + expected_shape = (1, 15, 3) + expected_logits = torch.tensor( + [[-6.7897, -16.9985, 6.7937], [-8.0186, -22.2192, 6.9677], [-7.3117, -21.0708, 7.4055]] + ) + expected_boxes = torch.tensor([[0.4867, 0.1767, 0.6732], [0.6718, 0.4479, 0.3830], [0.4716, 0.1760, 0.6364]]) + + else: + expected_shape = (1, 125, 7) + expected_logits = torch.tensor( + [[-18.1430, -8.3214, 4.8274], [-18.4685, -7.1361, -4.2667], [-26.3693, -9.3429, -4.9962]] + ) + expected_boxes = torch.tensor([[0.4983, 0.5595, 0.9440], [0.4916, 0.6315, 0.5954], [0.6108, 0.8637, 0.1135]]) + + assert outputs.logits.shape == expected_shape + assert torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4) + assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + # Save model and 
image processor + logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + # Push model to HF hub + logger.info("Pushing model to the hub...") + model_name = ( + "microsoft/table-transformer-detection" + if "detection" in checkpoint_url + else "microsoft/table-transformer-structure-recognition" + ) + model.push_to_hub(model_name, revision="no_timm") + image_processor.push_to_hub(model_name, revision="no_timm") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_url", + default="https://pubtables1m.blob.core.windows.net/model/pubtables1m_detection_detr_r18.pth", + type=str, + choices=[ + "https://pubtables1m.blob.core.windows.net/model/pubtables1m_detection_detr_r18.pth", + "https://pubtables1m.blob.core.windows.net/model/pubtables1m_structure_detr_r18.pth", + ], + help="URL of the Table Transformer checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + args = parser.parse_args() + convert_table_transformer_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/table_transformer/modeling_table_transformer.py b/docs/transformers/src/transformers/models/table_transformer/modeling_table_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..25246f3382014f993554e96db988273ce22650f6 --- /dev/null +++ b/docs/transformers/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -0,0 +1,1435 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
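+# Usage sketch (illustrative, not part of the original file): checkpoints produced by the
+# conversion scripts above can be loaded with the classes defined in this module, e.g.
+#
+#     from transformers import AutoImageProcessor, TableTransformerForObjectDetection
+#
+#     processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
+#     model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")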
+"""PyTorch Table Transformer model.""" + +import math +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_timm_available, + logging, + replace_return_docstrings, + requires_backends, +) +from ...utils.backbone_utils import load_backbone +from .configuration_table_transformer import TableTransformerConfig + + +if is_timm_available(): + from timm import create_model + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TableTransformerConfig" +_CHECKPOINT_FOR_DOC = "microsoft/table-transformer-detection" + + +@dataclass +# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput with DETR->TABLE_TRANSFORMER,Detr->TableTransformer +class TableTransformerDecoderOutput(BaseModelOutputWithCrossAttentions): + """ + Base class for outputs of the TABLE_TRANSFORMER decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, + namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them + gone through a layernorm. This is useful when training the model with auxiliary decoding losses. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. 
+ """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +# Copied from transformers.models.detr.modeling_detr.DetrModelOutput with DETR->TABLE_TRANSFORMER,Detr->TableTransformer +class TableTransformerModelOutput(Seq2SeqModelOutput): + """ + Base class for outputs of the TABLE_TRANSFORMER encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput, + namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them + gone through a layernorm. This is useful when training the model with auxiliary decoding losses. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each + layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. 
+ """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->TableTransformer,DetrImageProcessor->DetrImageProcessor +class TableTransformerObjectDetectionOutput(ModelOutput): + """ + Output type of [`TableTransformerForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~TableTransformerImageProcessor.post_process_object_detection`] to retrieve the + unnormalized bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each + layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: Optional[torch.FloatTensor] = None + pred_boxes: Optional[torch.FloatTensor] = None + auxiliary_outputs: Optional[List[Dict]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->TableTransformer +class TableTransformerFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->TableTransformer +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `TableTransformerFrozenBatchNorm2d`. 
+ + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = TableTransformerFrozenBatchNorm2d(module.num_features) + + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->TableTransformer +class TableTransformerConvEncoder(nn.Module): + """ + Convolutional backbone, using either the AutoBackbone API or one from the timm library. + + nn.BatchNorm2d layers are replaced by TableTransformerFrozenBatchNorm2d as defined above. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + + # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API + if config.use_timm_backbone: + # We default to values which were previously hard-coded. This enables configurability from the config + # using backbone arguments, while keeping the default behavior the same. + requires_backends(self, ["timm"]) + kwargs = getattr(config, "backbone_kwargs", {}) + kwargs = {} if kwargs is None else kwargs.copy() + out_indices = kwargs.pop("out_indices", (1, 2, 3, 4)) + num_channels = kwargs.pop("in_chans", config.num_channels) + if config.dilation: + kwargs["output_stride"] = kwargs.get("output_stride", 16) + backbone = create_model( + config.backbone, + pretrained=config.use_pretrained_backbone, + features_only=True, + out_indices=out_indices, + in_chans=num_channels, + **kwargs, + ) + else: + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = ( + self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels + ) + + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + + if "resnet" in backbone_model_type: + for name, parameter in self.model.named_parameters(): + if config.use_timm_backbone: + if "layer2" not in name and "layer3" not in name and "layer4" not in name: + parameter.requires_grad_(False) + else: + if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name: + parameter.requires_grad_(False) + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->TableTransformer +class TableTransformerConvModel(nn.Module): + """ + This module adds 2D position embeddings to all intermediate 
feature maps of the convolutional encoder. + """ + + def __init__(self, conv_encoder, position_embedding): + super().__init__() + self.conv_encoder = conv_encoder + self.position_embedding = position_embedding + + def forward(self, pixel_values, pixel_mask): + # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples + out = self.conv_encoder(pixel_values, pixel_mask) + pos = [] + for feature_map, mask in out: + # position encoding + pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) + + return out, pos + + +# Copied from transformers.models.detr.modeling_detr.DetrSinePositionEmbedding with Detr->TableTransformer +class TableTransformerSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = pixel_mask.cumsum(1, dtype=torch.float32) + x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding with Detr->TableTransformer +class TableTransformerLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. 
+ """ + + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + + def forward(self, pixel_values, pixel_mask=None): + height, width = pixel_values.shape[-2:] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + return pos + + +# Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->TableTransformer +def build_position_encoding(config): + n_steps = config.d_model // 2 + if config.position_embedding_type == "sine": + # TODO find a better way of exposing other arguments + position_embedding = TableTransformerSinePositionEmbedding(n_steps, normalize=True) + elif config.position_embedding_type == "learned": + position_embedding = TableTransformerLearnedPositionEmbedding(n_steps) + else: + raise ValueError(f"Not supported {config.position_embedding_type}") + + return position_embedding + + +# Copied from transformers.models.detr.modeling_detr.DetrAttention with DETR->TABLE_TRANSFORMER,Detr->TableTransformer +class TableTransformerAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the TABLE_TRANSFORMER paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]): + return tensor if object_queries is None else tensor + object_queries + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + spatial_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size, target_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + if object_queries is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, object_queries) + + # add key-value position embeddings to the key value states + if spatial_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+            # In order to do so, attn_weights has to be reshaped
+            # twice and then reused in the computation below
+            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
+            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+class TableTransformerEncoderLayer(nn.Module):
+    # Copied from transformers.models.detr.modeling_detr.DetrEncoderLayer.__init__ with Detr->TableTransformer
+    def __init__(self, config: TableTransformerConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.self_attn = TableTransformerAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        object_queries: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                values.
+            object_queries (`torch.FloatTensor`, *optional*): object queries, to be added to hidden_states.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
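+
+        Example (an illustrative sketch, not documented API usage; this layer is pre-norm, i.e. layer norm is
+        applied before attention and the feed-forward block, and the shapes below assume the default
+        `TableTransformerConfig` with `d_model=256`):
+
+        ```python
+        >>> import torch
+
+        >>> from transformers import TableTransformerConfig
+
+        >>> config = TableTransformerConfig()
+        >>> layer = TableTransformerEncoderLayer(config)
+
+        >>> # random stand-ins for flattened backbone features and their position embeddings
+        >>> hidden_states = torch.randn(1, 100, config.d_model)
+        >>> object_queries = torch.randn(1, 100, config.d_model)
+
+        >>> outputs = layer(hidden_states, attention_mask=None, object_queries=object_queries)
+        >>> list(outputs[0].shape)
+        [1, 100, 256]
+        ```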
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class TableTransformerDecoderLayer(nn.Module): + # Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer.__init__ with Detr->TableTransformer + def __init__(self, config: TableTransformerConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = TableTransformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = TableTransformerAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + object_queries: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + object_queries (`torch.FloatTensor`, *optional*): + object queries that are added to the queries and keys + in the cross-attention layer. + query_position_embeddings (`torch.FloatTensor`, *optional*): + object queries that are added to the queries and keys + in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + object_queries=query_position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + object_queries=query_position_embeddings, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + spatial_position_embeddings=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + # Fully Connected + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class TableTransformerPreTrainedModel(PreTrainedModel): + config_class = TableTransformerConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + _no_split_modules = [ + r"TableTransformerConvEncoder", + r"TableTransformerEncoderLayer", + r"TableTransformerDecoderLayer", + ] + + def _init_weights(self, module): + std = self.config.init_std + + if isinstance(module, TableTransformerLearnedPositionEmbedding): + nn.init.uniform_(module.row_embeddings.weight) + nn.init.uniform_(module.column_embeddings.weight) + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +TABLE_TRANSFORMER_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TableTransformerConfig`]): + Model configuration class with all the parameters of the model. 
Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TABLE_TRANSFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. + + Pixel values can be obtained using [`DetrImageProcessor`]. See [`DetrImageProcessor.__call__`] for details. + + pixel_mask (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): + Not used by default. Can be used to mask object queries. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class TableTransformerEncoder(TableTransformerPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`TableTransformerEncoderLayer`]. + + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweak for Table Transformer: + + - object_queries are added to the forward pass. 
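+
+    Note: despite its name, the `object_queries` argument of this encoder holds the 2D position embeddings of the
+    flattened feature map (built by `build_position_encoding`); they are added to the queries and keys in every
+    self-attention layer.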
+ + Args: + config: TableTransformerConfig + """ + + def __init__(self, config: TableTransformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + self.layers = nn.ModuleList([TableTransformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + + self.layernorm = nn.LayerNorm(config.d_model) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + object_queries=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + + object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
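+
+        Example (an illustrative sketch; the encoder is normally driven by [`TableTransformerModel`], so the
+        tensors below are random stand-ins for real backbone features):
+
+        ```python
+        >>> import torch
+
+        >>> from transformers import TableTransformerConfig
+
+        >>> config = TableTransformerConfig()
+        >>> encoder = TableTransformerEncoder(config)
+
+        >>> inputs_embeds = torch.randn(1, 100, config.d_model)  # flattened feature map
+        >>> object_queries = torch.randn(1, 100, config.d_model)  # matching position embeddings
+
+        >>> outputs = encoder(inputs_embeds=inputs_embeds, object_queries=object_queries)
+        >>> list(outputs.last_hidden_state.shape)
+        [1, 100, 256]
+        ```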
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + # we add object_queries as extra input to the encoder_layer + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + object_queries=object_queries, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + hidden_states = self.layernorm(hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.detr.modeling_detr.DetrDecoder with DETR->TABLE_TRANSFORMER,Detr->TableTransformer +class TableTransformerDecoder(TableTransformerPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TableTransformerDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for TABLE_TRANSFORMER: + + - object_queries and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + + Args: + config: TableTransformerConfig + """ + + def __init__(self, config: TableTransformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([TableTransformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in TABLE_TRANSFORMER, the decoder uses layernorm after the last decoder layer output + self.layernorm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + object_queries=None, + query_position_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The query embeddings that are passed into the decoder. 
+ + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Object queries that are added to the queries and keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + , *optional*): Position embeddings that are added to the values and keys in each self-attention layer. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + combined_attention_mask = None + + if attention_mask is not None and combined_attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # optional intermediate hidden states + intermediate = () if self.config.auxiliary_loss else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + if self.gradient_checkpointing and self.training: 
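+                # gradient checkpointing trades compute for memory: the intermediate activations of this
+                # decoder layer are recomputed during the backward pass instead of being kept in memory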
+ layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if self.config.auxiliary_loss: + hidden_states = self.layernorm(hidden_states) + intermediate += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # finally, apply layernorm + hidden_states = self.layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # stack intermediate decoder activations + if self.config.auxiliary_loss: + intermediate = torch.stack(intermediate) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] + if v is not None + ) + return TableTransformerDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=intermediate, + ) + + +@add_start_docstrings( + """ + The bare Table Transformer Model (consisting of a backbone and encoder-decoder Transformer) outputting raw + hidden-states without any specific head on top. + """, + TABLE_TRANSFORMER_START_DOCSTRING, +) +class TableTransformerModel(TableTransformerPreTrainedModel): + # Copied from transformers.models.detr.modeling_detr.DetrModel.__init__ with Detr->TableTransformer + def __init__(self, config: TableTransformerConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = TableTransformerConvEncoder(config) + object_queries = build_position_encoding(config) + self.backbone = TableTransformerConvModel(backbone, object_queries) + + # Create projection layer + self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1) + + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = TableTransformerEncoder(config) + self.decoder = TableTransformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for name, param in self.backbone.conv_encoder.model.named_parameters(): + param.requires_grad_(True) + + @add_start_docstrings_to_model_forward(TABLE_TRANSFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TableTransformerModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = 
None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TableTransformerModelOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TableTransformerModel
+        >>> from huggingface_hub import hf_hub_download
+        >>> from PIL import Image
+
+        >>> file_path = hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename="example_pdf.png")
+        >>> image = Image.open(file_path).convert("RGB")
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
+        >>> model = TableTransformerModel.from_pretrained("microsoft/table-transformer-detection")
+
+        >>> # prepare image for the model
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> # forward pass
+        >>> outputs = model(**inputs)
+
+        >>> # the last hidden states are the final query embeddings of the Transformer decoder
+        >>> # these are of shape (batch_size, num_queries, hidden_size)
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 15, 256]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, num_channels, height, width = pixel_values.shape
+        device = pixel_values.device
+
+        if pixel_mask is None:
+            pixel_mask = torch.ones((batch_size, height, width), device=device)
+
+        # First, send pixel_values + pixel_mask through the backbone to obtain the features
+        # pixel_values should be of shape (batch_size, num_channels, height, width)
+        # pixel_mask should be of shape (batch_size, height, width)
+        features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
+
+        # get final feature map and downsampled mask
+        feature_map, mask = features[-1]
+
+        if mask is None:
+            raise ValueError("Backbone does not return downsampled pixel mask")
+
+        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
+        projected_feature_map = self.input_projection(feature_map)
+
+        # Third, flatten the feature map + object queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
+        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
+        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+        object_queries = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
+
+        flattened_mask = mask.flatten(1)
+
+        # Fourth, send flattened_features + flattened_mask + object queries through the encoder
+        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
+        # flattened_mask is a Tensor of shape (batch_size, height*width)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                inputs_embeds=flattened_features,
+                attention_mask=flattened_mask,
+                object_queries=object_queries,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + object queries through the decoder (which is conditioned on the encoder output) + query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + queries = torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=queries, + attention_mask=None, + object_queries=object_queries, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=flattened_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TableTransformerModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + ) + + +@add_start_docstrings( + """ + Table Transformer Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on + top, for tasks such as COCO detection. + """, + TABLE_TRANSFORMER_START_DOCSTRING, +) +class TableTransformerForObjectDetection(TableTransformerPreTrainedModel): + # Copied from transformers.models.detr.modeling_detr.DetrForObjectDetection.__init__ with Detr->TableTransformer + def __init__(self, config: TableTransformerConfig): + super().__init__(config) + + # DETR encoder-decoder model + self.model = TableTransformerModel(config) + + # Object detection heads + self.class_labels_classifier = nn.Linear( + config.d_model, config.num_labels + 1 + ) # We add one for the "no object" class + self.bbox_predictor = TableTransformerMLPPredictionHead( + input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TABLE_TRANSFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TableTransformerObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[Dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], TableTransformerObjectDetectionOutput]: + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). 
The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Returns: + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> from transformers import AutoImageProcessor, TableTransformerForObjectDetection + >>> import torch + >>> from PIL import Image + + >>> file_path = hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename="example_pdf.png") + >>> image = Image.open(file_path).convert("RGB") + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection") + >>> model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + Detected table with confidence 1.0 at location [202.1, 210.59, 1119.22, 385.09] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # First, sent images through TABLE_TRANSFORMER base model to obtain encoder + decoder outputs + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # class logits + predicted bounding boxes + logits = self.class_labels_classifier(sequence_output) + pred_boxes = self.bbox_predictor(sequence_output).sigmoid() + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + outputs_class, outputs_coord = None, None + if self.config.auxiliary_loss: + intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4] + outputs_class = self.class_labels_classifier(intermediate) + outputs_coord = self.bbox_predictor(intermediate).sigmoid() + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + auxiliary_outputs + outputs + else: + output = (logits, pred_boxes) + outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return TableTransformerObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + 
encoder_attentions=outputs.encoder_attentions, + ) + + +# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->TableTransformer,detr->table_transformer +class TableTransformerMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/table_transformer/blob/master/models/table_transformer.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +__all__ = ["TableTransformerForObjectDetection", "TableTransformerModel", "TableTransformerPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/tapas/__init__.py b/docs/transformers/src/transformers/models/tapas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7df7e765f60e84f181a7bc6ac3668f35e796b8c7 --- /dev/null +++ b/docs/transformers/src/transformers/models/tapas/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_tapas import * + from .modeling_tapas import * + from .modeling_tf_tapas import * + from .tokenization_tapas import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/tapas/configuration_tapas.py b/docs/transformers/src/transformers/models/tapas/configuration_tapas.py new file mode 100644 index 0000000000000000000000000000000000000000..58769e99a7221822dcd4a8309eb97862d0f6b022 --- /dev/null +++ b/docs/transformers/src/transformers/models/tapas/configuration_tapas.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2020 Google Research and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TAPAS configuration. 
Based on the BERT configuration with added parameters. + +Hyperparameters are taken from run_task_main.py and hparam_utils.py of the original implementation. URLS: + +- https://github.com/google-research/tapas/blob/master/tapas/run_task_main.py +- https://github.com/google-research/tapas/blob/master/tapas/utils/hparam_utils.py + +""" + +from ...configuration_utils import PretrainedConfig + + +class TapasConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TapasModel`]. It is used to instantiate a TAPAS + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the TAPAS + [google/tapas-base-finetuned-sqa](https://huggingface.co/google/tapas-base-finetuned-sqa) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Hyperparameters additional to BERT are taken from run_task_main.py and hparam_utils.py of the original + implementation. Original implementation available at https://github.com/google-research/tapas/tree/master. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the TAPAS model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`TapasModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"swish"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_sizes (`List[int]`, *optional*, defaults to `[3, 256, 256, 2, 256, 256, 10]`): + The vocabulary sizes of the `token_type_ids` passed when calling [`TapasModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + positive_label_weight (`float`, *optional*, defaults to 10.0): + Weight for positive labels. + num_aggregation_labels (`int`, *optional*, defaults to 0): + The number of aggregation operators to predict. + aggregation_loss_weight (`float`, *optional*, defaults to 1.0): + Importance weight for the aggregation loss. 
+ use_answer_as_supervision (`bool`, *optional*): + Whether to use the answer as the only supervision for aggregation examples. + answer_loss_importance (`float`, *optional*, defaults to 1.0): + Importance weight for the regression loss. + use_normalized_answer_loss (`bool`, *optional*, defaults to `False`): + Whether to normalize the answer loss by the maximum of the predicted and expected value. + huber_loss_delta (`float`, *optional*): + Delta parameter used to calculate the regression loss. + temperature (`float`, *optional*, defaults to 1.0): + Value used to control (OR change) the skewness of cell logits probabilities. + aggregation_temperature (`float`, *optional*, defaults to 1.0): + Scales aggregation logits to control the skewness of probabilities. + use_gumbel_for_cells (`bool`, *optional*, defaults to `False`): + Whether to apply Gumbel-Softmax to cell selection. + use_gumbel_for_aggregation (`bool`, *optional*, defaults to `False`): + Whether to apply Gumbel-Softmax to aggregation selection. + average_approximation_function (`string`, *optional*, defaults to `"ratio"`): + Method to calculate the expected average of cells in the weak supervision case. One of `"ratio"`, + `"first_order"` or `"second_order"`. + cell_selection_preference (`float`, *optional*): + Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for + aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE" + operator) is higher than this hyperparameter, then aggregation is predicted for an example. + answer_loss_cutoff (`float`, *optional*): + Ignore examples with answer loss larger than cutoff. + max_num_rows (`int`, *optional*, defaults to 64): + Maximum number of rows. + max_num_columns (`int`, *optional*, defaults to 32): + Maximum number of columns. + average_logits_per_cell (`bool`, *optional*, defaults to `False`): + Whether to average logits per cell. + select_one_column (`bool`, *optional*, defaults to `True`): + Whether to constrain the model to only select cells from a single column. + allow_empty_column_selection (`bool`, *optional*, defaults to `False`): + Whether to allow not to select any column. + init_cell_selection_weights_to_zero (`bool`, *optional*, defaults to `False`): + Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%. + reset_position_index_per_cell (`bool`, *optional*, defaults to `True`): + Whether to restart position indexes at every cell (i.e. use relative position embeddings). + disable_per_token_loss (`bool`, *optional*, defaults to `False`): + Whether to disable any (strong or weak) supervision on cells. + aggregation_labels (`Dict[int, label]`, *optional*): + The aggregation labels used to aggregate the results. For example, the WTQ models have the following + aggregation labels: `{0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}` + no_aggregation_label_index (`int`, *optional*): + If the aggregation labels are defined and one of these labels represents "No aggregation", this should be + set to its index. For example, the WTQ models have the "NONE" aggregation label at index 0, so that value + should be set to 0 for these models. 
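+
+    As an illustrative sketch (the values below are placeholders in the spirit of the WTQ-style settings described
+    above, not official hyperparameters), a weakly supervised aggregation configuration could be created as follows:
+
+    ```python
+    >>> from transformers import TapasConfig
+
+    >>> config = TapasConfig(
+    ...     num_aggregation_labels=4,
+    ...     use_answer_as_supervision=True,
+    ...     aggregation_labels={0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"},
+    ...     no_aggregation_label_index=0,
+    ... )
+    ```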
+ + + Example: + + ```python + >>> from transformers import TapasModel, TapasConfig + + >>> # Initializing a default (SQA) Tapas configuration + >>> configuration = TapasConfig() + >>> # Initializing a model from the configuration + >>> model = TapasModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "tapas" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_sizes=[3, 256, 256, 2, 256, 256, 10], + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + positive_label_weight=10.0, + num_aggregation_labels=0, + aggregation_loss_weight=1.0, + use_answer_as_supervision=None, + answer_loss_importance=1.0, + use_normalized_answer_loss=False, + huber_loss_delta=None, + temperature=1.0, + aggregation_temperature=1.0, + use_gumbel_for_cells=False, + use_gumbel_for_aggregation=False, + average_approximation_function="ratio", + cell_selection_preference=None, + answer_loss_cutoff=None, + max_num_rows=64, + max_num_columns=32, + average_logits_per_cell=False, + select_one_column=True, + allow_empty_column_selection=False, + init_cell_selection_weights_to_zero=False, + reset_position_index_per_cell=True, + disable_per_token_loss=False, + aggregation_labels=None, + no_aggregation_label_index=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + # BERT hyperparameters (with updated max_position_embeddings and type_vocab_sizes) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_sizes = type_vocab_sizes + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + # Fine-tuning task hyperparameters + self.positive_label_weight = positive_label_weight + self.num_aggregation_labels = num_aggregation_labels + self.aggregation_loss_weight = aggregation_loss_weight + self.use_answer_as_supervision = use_answer_as_supervision + self.answer_loss_importance = answer_loss_importance + self.use_normalized_answer_loss = use_normalized_answer_loss + self.huber_loss_delta = huber_loss_delta + self.temperature = temperature + self.aggregation_temperature = aggregation_temperature + self.use_gumbel_for_cells = use_gumbel_for_cells + self.use_gumbel_for_aggregation = use_gumbel_for_aggregation + self.average_approximation_function = average_approximation_function + self.cell_selection_preference = cell_selection_preference + self.answer_loss_cutoff = answer_loss_cutoff + self.max_num_rows = max_num_rows + self.max_num_columns = max_num_columns + self.average_logits_per_cell = average_logits_per_cell + self.select_one_column = select_one_column + self.allow_empty_column_selection = allow_empty_column_selection + self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero + self.reset_position_index_per_cell = reset_position_index_per_cell + self.disable_per_token_loss = disable_per_token_loss + + # Aggregation hyperparameters + self.aggregation_labels = aggregation_labels + 
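+        # keys of `aggregation_labels` may come back as strings (e.g. after a JSON round-trip of the config);
+        # they are normalized to integer keys below
+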
self.no_aggregation_label_index = no_aggregation_label_index + + if isinstance(self.aggregation_labels, dict): + self.aggregation_labels = {int(k): v for k, v in aggregation_labels.items()} + + +__all__ = ["TapasConfig"] diff --git a/docs/transformers/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..34bf77cccd6bdd36853fe88488fd864e86167087 --- /dev/null +++ b/docs/transformers/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert TAPAS checkpoint.""" + +import argparse + +from transformers import ( + TapasConfig, + TapasForMaskedLM, + TapasForQuestionAnswering, + TapasForSequenceClassification, + TapasModel, + TapasTokenizer, + load_tf_weights_in_tapas, +) +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch( + task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path +): + # Initialise PyTorch model. + # If you want to convert a checkpoint that uses absolute position embeddings, make sure to set reset_position_index_per_cell of + # TapasConfig to False. 
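+
+    # Illustrative invocation (all paths below are placeholders):
+    #   python convert_tapas_original_tf_checkpoint_to_pytorch.py \
+    #       --task WTQ \
+    #       --reset_position_index_per_cell \
+    #       --tf_checkpoint_path /path/to/model.ckpt \
+    #       --tapas_config_file /path/to/tapas_config.json \
+    #       --pytorch_dump_path /path/to/pytorch_model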
+
+    # initialize configuration from json file
+    config = TapasConfig.from_json_file(tapas_config_file)
+    # set absolute/relative position embeddings parameter
+    config.reset_position_index_per_cell = reset_position_index_per_cell
+
+    # set remaining parameters of TapasConfig as well as the model based on the task
+    if task == "SQA":
+        model = TapasForQuestionAnswering(config=config)
+    elif task == "WTQ":
+        # run_task_main.py hparams
+        config.num_aggregation_labels = 4
+        config.use_answer_as_supervision = True
+        # hparam_utils.py hparams
+        config.answer_loss_cutoff = 0.664694
+        config.cell_selection_preference = 0.207951
+        config.huber_loss_delta = 0.121194
+        config.init_cell_selection_weights_to_zero = True
+        config.select_one_column = True
+        config.allow_empty_column_selection = False
+        config.temperature = 0.0352513
+
+        model = TapasForQuestionAnswering(config=config)
+    elif task == "WIKISQL_SUPERVISED":
+        # run_task_main.py hparams
+        config.num_aggregation_labels = 4
+        config.use_answer_as_supervision = False
+        # hparam_utils.py hparams
+        config.answer_loss_cutoff = 36.4519
+        config.cell_selection_preference = 0.903421
+        config.huber_loss_delta = 222.088
+        config.init_cell_selection_weights_to_zero = True
+        config.select_one_column = True
+        config.allow_empty_column_selection = True
+        config.temperature = 0.763141
+
+        model = TapasForQuestionAnswering(config=config)
+    elif task == "TABFACT":
+        model = TapasForSequenceClassification(config=config)
+    elif task == "MLM":
+        model = TapasForMaskedLM(config=config)
+    elif task == "INTERMEDIATE_PRETRAINING":
+        model = TapasModel(config=config)
+    else:
+        raise ValueError(f"Task {task} not supported.")
+
+    print(f"Building PyTorch model from configuration: {config}")
+    # Load weights from tf checkpoint
+    load_tf_weights_in_tapas(model, config, tf_checkpoint_path)
+
+    # Save pytorch-model (weights and configuration)
+    print(f"Save PyTorch model to {pytorch_dump_path}")
+    model.save_pretrained(pytorch_dump_path)
+
+    # Save tokenizer files
+    print(f"Save tokenizer files to {pytorch_dump_path}")
+    tokenizer = TapasTokenizer(vocab_file=tf_checkpoint_path[:-10] + "vocab.txt", model_max_length=512)
+    tokenizer.save_pretrained(pytorch_dump_path)
+
+    print("Used relative position embeddings:", model.config.reset_position_index_per_cell)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--task", default="SQA", type=str, help="Model task for which to convert a checkpoint. Defaults to SQA."
+    )
+    parser.add_argument(
+        "--reset_position_index_per_cell",
+        default=False,
+        action="store_true",
+        help="Whether to use relative position embeddings. Defaults to False (absolute position embeddings).",
+    )
+    parser.add_argument(
+        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint."
+    )
+    parser.add_argument(
+        "--tapas_config_file",
+        default=None,
+        type=str,
+        required=True,
+        help=(
+            "The config json file corresponding to the pre-trained TAPAS model. \n"
+            "This specifies the model architecture."
+        ),
+    )
+    parser.add_argument(
+        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+ ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.task, + args.reset_position_index_per_cell, + args.tf_checkpoint_path, + args.tapas_config_file, + args.pytorch_dump_path, + ) diff --git a/docs/transformers/src/transformers/models/tapas/modeling_tapas.py b/docs/transformers/src/transformers/models/tapas/modeling_tapas.py new file mode 100644 index 0000000000000000000000000000000000000000..b54d5be1ef8a471bd466422b35030581bd947ca3 --- /dev/null +++ b/docs/transformers/src/transformers/models/tapas/modeling_tapas.py @@ -0,0 +1,2398 @@ +# coding=utf-8 +# Copyright 2020 Google Research and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TAPAS model.""" + +import enum +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_tapas import TapasConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TapasConfig" +_CHECKPOINT_FOR_DOC = "google/tapas-base" + + +EPSILON_ZERO_DIVISION = 1e-10 +CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0 + + +@dataclass +class TableQuestionAnsweringOutput(ModelOutput): + """ + Output type of [`TapasForQuestionAnswering`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)): + Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the + semi-supervised regression loss and (optionally) supervised loss for aggregations. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Prediction scores of the cell selection head, for every token. + logits_aggregation (`torch.FloatTensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`): + Prediction scores of the aggregation head, for every aggregation operator. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + logits_aggregation: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def load_tf_weights_in_tapas(model, config, tf_checkpoint_path): + """ + Load tf checkpoints in a PyTorch model. This is an adaptation from load_tf_weights_in_bert + + - add cell selection and aggregation heads + - take into account additional token type embedding layers + """ + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + "seq_relationship", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + # in case the model is TapasForSequenceClassification, we skip output_bias and output_weights + # since these are not used for classification + if isinstance(model, TapasForSequenceClassification): + if any(n in ["output_bias", "output_weights"] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + # in case the model is TapasModel, we skip output_bias, output_weights, output_bias_cls and output_weights_cls + # since this model does not have MLM and NSP heads + if isinstance(model, TapasModel): + if any(n in ["output_bias", "output_weights", "output_bias_cls", "output_weights_cls"] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + # in case the model is TapasForMaskedLM, we skip the pooler + if isinstance(model, TapasForMaskedLM): + if any(n in ["pooler"] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + continue + # if first scope name starts with "bert", change it to "tapas" + if name[0] == "bert": + name[0] = "tapas" + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + # cell selection heads + elif scope_names[0] == "output_bias": + if not isinstance(model, TapasForMaskedLM): + pointer = getattr(pointer, "output_bias") + else: + 
pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights":
+                pointer = getattr(pointer, "output_weights")
+            elif scope_names[0] == "column_output_bias":
+                pointer = getattr(pointer, "column_output_bias")
+            elif scope_names[0] == "column_output_weights":
+                pointer = getattr(pointer, "column_output_weights")
+            # aggregation head
+            elif scope_names[0] == "output_bias_agg":
+                pointer = getattr(pointer, "aggregation_classifier")
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights_agg":
+                pointer = getattr(pointer, "aggregation_classifier")
+                pointer = getattr(pointer, "weight")
+            # classification head
+            elif scope_names[0] == "output_bias_cls":
+                pointer = getattr(pointer, "classifier")
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights_cls":
+                pointer = getattr(pointer, "classifier")
+                pointer = getattr(pointer, "weight")
+            else:
+                try:
+                    pointer = getattr(pointer, scope_names[0])
+                except AttributeError:
+                    logger.info(f"Skipping {'/'.join(name)}")
+                    continue
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        if m_name[-11:] == "_embeddings":
+            pointer = getattr(pointer, "weight")
+        elif m_name[-13:] in [f"_embeddings_{i}" for i in range(7)]:
+            pointer = getattr(pointer, "weight")
+        elif m_name == "kernel":
+            array = np.transpose(array)
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+        logger.info(f"Initialize PyTorch weight {name}")
+        # Added a check to see whether the array is a scalar (because bias terms in Tapas checkpoints can be
+        # scalar => should first be converted to numpy arrays)
+        if np.isscalar(array):
+            array = np.array(array)
+        pointer.data = torch.from_numpy(array)
+    return model
+
+
+class TapasEmbeddings(nn.Module):
+    """
+    Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of
+    additional token type embeddings to encode tabular structure.
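+
+    The token type ids are expected to have shape `(batch_size, seq_len, len(config.type_vocab_sizes))`, i.e. one
+    channel per structural feature. As listed further down in this file, the seven channels used by TAPAS are
+    `segment_ids`, `column_ids`, `row_ids`, `prev_labels`, `column_ranks`, `inv_column_ranks` and `numeric_relations`.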
+ """ + + def __init__(self, config): + super().__init__() + # we do not include config.disabled_features and config.disable_position_embeddings from the original implementation + # word embeddings + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + # position embeddings + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + # token type embeddings + for i, type_vocab_sizes in enumerate(config.type_vocab_sizes): + name = f"token_type_embeddings_{i}" + setattr(self, name, nn.Embedding(type_vocab_sizes, config.hidden_size)) + + self.number_of_token_type_embeddings = len(config.type_vocab_sizes) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.config = config + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if position_ids is None: + # create absolute position embeddings + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + # when self.config.reset_position_index_per_cell is set to True, create relative position embeddings + if self.config.reset_position_index_per_cell: + # shape (batch_size, seq_len) + col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1) + # shape (batch_size, seq_len) + row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1) + # shape (batch_size, seq_len) + full_index = ProductIndexMap(col_index, row_index) + # shape (max_rows * max_columns,). First absolute position for every cell + first_position_per_segment = reduce_min(position_ids, full_index)[0] + # ? shape (batch_size, seq_len). 
First absolute position of the cell for every token
+                first_position = gather(first_position_per_segment, full_index)
+                # shape (1, seq_len)
+                position = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)
+                # clamp the relative position at max_position_embeddings - 1 so it stays a valid embedding index
+                position_ids = torch.min(
+                    torch.as_tensor(self.config.max_position_embeddings - 1, device=device), position - first_position
+                )
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                (*input_shape, self.number_of_token_type_embeddings), dtype=torch.long, device=device
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        position_embeddings = self.position_embeddings(position_ids)
+
+        embeddings = inputs_embeds + position_embeddings
+
+        for i in range(self.number_of_token_type_embeddings):
+            name = f"token_type_embeddings_{i}"
+            embeddings += getattr(self, name)(token_type_ids[:, :, i])
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class TapasSelfAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.is_decoder = config.is_decoder
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
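+        # Shape bookkeeping (descriptive comments only): hidden_states is (batch_size, seq_len, hidden_size);
+        # transpose_for_scores reshapes each projection to (batch_size, num_attention_heads, seq_len,
+        # attention_head_size) so the matmul below scores every query/key pair per head.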
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TapasModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class TapasSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class TapasAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = TapasSelfAttention(config) + self.output = TapasSelfOutput(config) + self.pruned_heads = set() + + # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + 
self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + # Copied from transformers.models.bert.modeling_bert.BertAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class TapasIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class TapasOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class TapasLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = TapasAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TapasAttention(config) + self.intermediate = TapasIntermediate(config) + self.output = TapasOutput(config) + + # Copied from transformers.models.bert.modeling_bert.BertLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] 
= False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class TapasEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + 
attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_values, + output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class TapasPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Tapas +class TapasPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Tapas +class TapasLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = TapasPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Tapas +class TapasOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = TapasLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class TapasPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = TapasConfig + base_model_prefix = "tapas" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->Tapas + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, TapasLMPredictionHead): + module.bias.data.zero_() + + +TAPAS_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TapasConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TAPAS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0}, 7)`, *optional*): + Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this + class for more info. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. If + `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be + used. Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1 + indicates the head is **not masked**, - 0 indicates the head is **masked**. 
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Tapas Model transformer outputting raw hidden-states without any specific head on top.", + TAPAS_START_DOCSTRING, +) +class TapasModel(TapasPreTrainedModel): + """ + This class is a small change compared to [`BertModel`], taking into account the additional token type ids. + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = TapasEmbeddings(config) + self.encoder = TapasEncoder(config) + + self.pooler = TapasPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasModel + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base") + >>> model = TapasModel.from_pretrained("google/tapas-base") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... 
"Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"] + + >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + 
pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING) +class TapasForMaskedLM(TapasPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + config_class = TapasConfig + base_model_prefix = "tapas" + + def __init__(self, config): + super().__init__(config) + + self.tapas = TapasModel(config, add_pooling_layer=False) + self.cls = TapasOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasForMaskedLM + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base") + >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... "Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + + >>> inputs = tokenizer( + ... table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="pt" + ... ) + >>> labels = tokenizer( + ... table=table, queries="How many movies has George Clooney played in?", return_tensors="pt" + ... 
)["input_ids"] + + >>> outputs = model(**inputs, labels=labels) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.tapas( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables + (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for + SQA, WTQ or WikiSQL-supervised tasks. + """, + TAPAS_START_DOCSTRING, +) +class TapasForQuestionAnswering(TapasPreTrainedModel): + def __init__(self, config: TapasConfig): + super().__init__(config) + + # base model + self.tapas = TapasModel(config) + + # dropout (only used when training) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # cell selection heads + if config.init_cell_selection_weights_to_zero: + # init_cell_selection_weights_to_zero: Whether the initial weights should be + # set to 0. This ensures that all tokens have the same prior probability. 
+ self.output_weights = nn.Parameter(torch.zeros(config.hidden_size)) + self.column_output_weights = nn.Parameter(torch.zeros(config.hidden_size)) + else: + self.output_weights = nn.Parameter(torch.empty(config.hidden_size)) + nn.init.normal_( + self.output_weights, std=config.initializer_range + ) # here, a truncated normal is used in the original implementation + self.column_output_weights = nn.Parameter(torch.empty(config.hidden_size)) + nn.init.normal_( + self.column_output_weights, std=config.initializer_range + ) # here, a truncated normal is used in the original implementation + self.output_bias = nn.Parameter(torch.zeros([])) + self.column_output_bias = nn.Parameter(torch.zeros([])) + + # aggregation head + if config.num_aggregation_labels > 0: + self.aggregation_classifier = nn.Linear(config.hidden_size, config.num_aggregation_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TableQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + table_mask: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + aggregation_labels: Optional[torch.LongTensor] = None, + float_answer: Optional[torch.FloatTensor] = None, + numeric_values: Optional[torch.FloatTensor] = None, + numeric_values_scale: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TableQuestionAnsweringOutput]: + r""" + table_mask (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*): + Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and + padding are 0. + labels (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*): + Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the + answer appearing in the table. Can be obtained using [`AutoTokenizer`]. + + - 1 for tokens that are **part of the answer**, + - 0 for tokens that are **not part of the answer**. + + aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + Aggregation function index for every example in the batch for computing the aggregation loss. Indices + should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for + aggregation (WikiSQL-supervised). + float_answer (`torch.FloatTensor` of shape `(batch_size, )`, *optional*): + Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only + required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss. + numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*): + Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using + [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the + regression loss. 
+ numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*): + Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case + of weak supervision for aggregation (WTQ) to calculate the regression loss. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasForQuestionAnswering + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq") + >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... "Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"] + + >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> logits_aggregation = outputs.logits_aggregation + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.tapas( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + pooled_output = outputs[1] + + sequence_output = self.dropout(sequence_output) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # Construct indices for the table. + if token_type_ids is None: + token_type_ids = torch.zeros( + (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device + ) + + token_types = [ + "segment_ids", + "column_ids", + "row_ids", + "prev_labels", + "column_ranks", + "inv_column_ranks", + "numeric_relations", + ] + + row_ids = token_type_ids[:, :, token_types.index("row_ids")] + column_ids = token_type_ids[:, :, token_types.index("column_ids")] + + row_index = IndexMap( + indices=torch.min(row_ids, torch.as_tensor(self.config.max_num_rows - 1, device=row_ids.device)), + num_segments=self.config.max_num_rows, + batch_dims=1, + ) + col_index = IndexMap( + indices=torch.min(column_ids, torch.as_tensor(self.config.max_num_columns - 1, device=column_ids.device)), + num_segments=self.config.max_num_columns, + batch_dims=1, + ) + cell_index = ProductIndexMap(row_index, col_index) + + # Masks. + input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + # Table cells only, without question tokens and table headers. + if table_mask is None: + table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids)) + # torch.FloatTensor[batch_size, seq_length] + input_mask_float = attention_mask.to(device=device, dtype=torch.float) + table_mask_float = table_mask.to(device=device, dtype=torch.float) + # Mask for cells that exist in the table (i.e. that are not padding). + cell_mask, _ = reduce_mean(input_mask_float, cell_index) + + # Compute logits per token. 
These are used to select individual cells. + logits = compute_token_logits(sequence_output, self.config.temperature, self.output_weights, self.output_bias) + + # Compute logits per column. These are used to select a column. + column_logits = None + if self.config.select_one_column: + column_logits = compute_column_logits( + sequence_output, + self.column_output_weights, + self.column_output_bias, + cell_index, + cell_mask, + self.config.allow_empty_column_selection, + ) + + # Aggregation logits + logits_aggregation = None + if self.config.num_aggregation_labels > 0: + logits_aggregation = self.aggregation_classifier(pooled_output) + + # Total loss calculation + total_loss = 0.0 + calculate_loss = False + if labels is not None: + calculate_loss = True + is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision + + # Semi-supervised cell selection in case of no aggregation: + # If the answer (the denotation) appears directly in the table we might + # select the answer without applying any aggregation function. There are + # some ambiguous cases, see utils._calculate_aggregate_mask for more info. + # `aggregate_mask` is 1 for examples where we chose to aggregate and 0 + # for examples where we chose to select the answer directly. + # `labels` encodes the positions of the answer appearing in the table. + if is_supervised: + aggregate_mask = None + else: + if float_answer is not None: + assert labels.shape[0] == float_answer.shape[0], ( + "Make sure the answers are a FloatTensor of shape (batch_size,)" + ) + # [batch_size] + aggregate_mask = _calculate_aggregate_mask( + float_answer, + pooled_output, + self.config.cell_selection_preference, + labels, + self.aggregation_classifier, + ) + else: + raise ValueError("You have to specify float answers in order to calculate the aggregate mask") + + # Cell selection log-likelihood + if self.config.average_logits_per_cell: + logits_per_cell, _ = reduce_mean(logits, cell_index) + logits = gather(logits_per_cell, cell_index) + dist_per_token = torch.distributions.Bernoulli(logits=logits) + + # Compute cell selection loss per example. + selection_loss_per_example = None + if not self.config.select_one_column: + weight = torch.where( + labels == 0, + torch.ones_like(labels, dtype=torch.float32), + self.config.positive_label_weight * torch.ones_like(labels, dtype=torch.float32), + ) + selection_loss_per_token = -dist_per_token.log_prob(labels) * weight + selection_loss_per_example = torch.sum(selection_loss_per_token * input_mask_float, dim=1) / ( + torch.sum(input_mask_float, dim=1) + EPSILON_ZERO_DIVISION + ) + else: + selection_loss_per_example, logits = _single_column_cell_selection_loss( + logits, column_logits, labels, cell_index, col_index, cell_mask + ) + dist_per_token = torch.distributions.Bernoulli(logits=logits) + + # Supervised cell selection + if self.config.disable_per_token_loss: + pass + elif is_supervised: + total_loss += torch.mean(selection_loss_per_example) + else: + # For the not supervised case, do not assign loss for cell selection + total_loss += torch.mean(selection_loss_per_example * (1.0 - aggregate_mask)) + + # Semi-supervised regression loss and supervised loss for aggregations + if self.config.num_aggregation_labels > 0: + if is_supervised: + # Note that `aggregate_mask` is None if the setting is supervised. 
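+                    # In the strongly supervised setting (e.g. WIKISQL_SUPERVISED), a gold aggregation
+                    # operator index is provided per example, so the aggregation head can be trained
+                    # with a plain classification loss rather than inferred from the answer.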
+ if aggregation_labels is not None: + assert labels.shape[0] == aggregation_labels.shape[0], ( + "Make sure the aggregation labels are a LongTensor of shape (batch_size,)" + ) + per_example_additional_loss = _calculate_aggregation_loss( + logits_aggregation, + aggregate_mask, + aggregation_labels, + self.config.use_answer_as_supervision, + self.config.num_aggregation_labels, + self.config.aggregation_loss_weight, + ) + else: + raise ValueError( + "You have to specify aggregation labels in order to calculate the aggregation loss" + ) + else: + # Set aggregation labels to zeros + aggregation_labels = torch.zeros(labels.shape[0], dtype=torch.long, device=labels.device) + per_example_additional_loss = _calculate_aggregation_loss( + logits_aggregation, + aggregate_mask, + aggregation_labels, + self.config.use_answer_as_supervision, + self.config.num_aggregation_labels, + self.config.aggregation_loss_weight, + ) + + if self.config.use_answer_as_supervision: + if numeric_values is not None and numeric_values_scale is not None: + assert numeric_values.shape == numeric_values_scale.shape + # Add regression loss for numeric answers which require aggregation. + answer_loss, large_answer_loss_mask = _calculate_regression_loss( + float_answer, + aggregate_mask, + dist_per_token, + numeric_values, + numeric_values_scale, + table_mask_float, + logits_aggregation, + self.config, + ) + per_example_additional_loss += answer_loss + # Zero loss for examples with answer_loss > cutoff. + per_example_additional_loss *= large_answer_loss_mask + else: + raise ValueError( + "You have to specify numeric values and numeric values scale in order to calculate the" + " regression loss" + ) + + total_loss += torch.mean(per_example_additional_loss) + + else: + # if no label ids are provided, set them to zeros in order to properly compute logits + labels = torch.zeros_like(logits) + _, logits = _single_column_cell_selection_loss( + logits, column_logits, labels, cell_index, col_index, cell_mask + ) + if not return_dict: + output = (logits, logits_aggregation) + outputs[2:] + return ((total_loss,) + output) if calculate_loss else output + + return TableQuestionAnsweringOutput( + loss=total_loss if calculate_loss else None, + logits=logits, + logits_aggregation=logits_aggregation, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table + entailment tasks, such as TabFact (Chen et al., 2020). 
+ """, + TAPAS_START_DOCSTRING, +) +class TapasForSequenceClassification(TapasPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.tapas = TapasModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called + "classification_class_index" in the original implementation. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasForSequenceClassification + >>> import torch + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact") + >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... "Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + >>> queries = [ + ... "There is only one actor who is 45 years old", + ... "There are 3 actors which played in more than 60 movies", + ... 
] + + >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt") + >>> labels = torch.tensor([1, 0]) # 1 means entailed, 0 means refuted + + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.tapas( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +""" TAPAS utilities.""" + + +class AverageApproximationFunction(str, enum.Enum): + RATIO = "ratio" + FIRST_ORDER = "first_order" + SECOND_ORDER = "second_order" + + +# Beginning of everything related to segmented tensors + + +class IndexMap: + """Index grouping entries within a tensor.""" + + def __init__(self, indices, num_segments, batch_dims=0): + """ + Creates an index + + Args: + indices (`torch.LongTensor`, same shape as a *values* Tensor to which the indices refer): + Tensor containing the indices. + num_segments (`torch.LongTensor`): + Scalar tensor, the number of segments. All elements in a batched segmented tensor must have the same + number of segments (although many segments can be empty). + batch_dims (`int`, *optional*, defaults to 0): + The number of batch dimensions. The first *batch_dims* dimensions of a SegmentedTensor are treated as + batch dimensions. Segments in different batch elements are always distinct even if they have the same + index. + """ + self.indices = torch.as_tensor(indices, device=indices.device) + self.num_segments = torch.as_tensor(num_segments, device=indices.device) + self.batch_dims = batch_dims + + def batch_shape(self): + return self.indices.size()[: self.batch_dims] # returns a torch.Size object + + +class ProductIndexMap(IndexMap): + """The product of two indices.""" + + def __init__(self, outer_index, inner_index): + """ + Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the + intersection of segments i and j. 
For example, if the inputs represent table cells indexed by rows
+        and columns respectively, the output will be a table indexed by (row, column) pairs, i.e. by cell. The
+        implementation combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has
+        *num_segments* equal to *outer_index.num_segments* * *inner_index.num_segments*.
+
+        Args:
+            outer_index (`IndexMap`):
+                IndexMap.
+            inner_index (`IndexMap`):
+                IndexMap, must have the same shape as *outer_index*.
+        """
+        if outer_index.batch_dims != inner_index.batch_dims:
+            raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")
+
+        super().__init__(
+            indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
+            num_segments=inner_index.num_segments * outer_index.num_segments,
+            batch_dims=inner_index.batch_dims,
+        )
+        self.outer_index = outer_index
+        self.inner_index = inner_index
+
+    def project_outer(self, index):
+        """Projects an index with the same index set onto the outer components."""
+        indices = torch.div(index.indices, self.inner_index.num_segments, rounding_mode="floor").type(torch.long)
+        return IndexMap(indices=indices, num_segments=self.outer_index.num_segments, batch_dims=index.batch_dims)
+
+    def project_inner(self, index):
+        """Projects an index with the same index set onto the inner components."""
+        return IndexMap(
+            indices=torch.fmod(index.indices, self.inner_index.num_segments)
+            .type(torch.float)
+            .floor()
+            .type(torch.long),
+            num_segments=self.inner_index.num_segments,
+            batch_dims=index.batch_dims,
+        )
+
+
+def gather(values, index, name="segmented_gather"):
+    """
+    Gathers from *values* using the index map. For each element in the domain of the index map this operation looks up
+    a value for that index in *values*. Two elements from the same segment always get assigned the same value.
+
+    Args:
+        values (`torch.Tensor` of shape (B1, ..., Bn, num_segments, V1, ...)):
+            Tensor with segment values.
+        index (`IndexMap` of shape (B1, ..., Bn, I1, ..., Ik)):
+            IndexMap.
+        name (`str`, *optional*, defaults to 'segmented_gather'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        `torch.Tensor`: Tensor of shape (B1, ..., Bn, I1, ..., Ik, V1, ...) with the gathered values.
+    """
+    indices = index.indices
+    # first, check whether the indices of the index represent scalar values (i.e. not vectorized)
+    if len(values.shape[index.batch_dims :]) < 2:
+        return torch.gather(
+            values,
+            index.batch_dims,
+            indices.view(
+                values.size()[0], -1
+            ),  # torch.gather expects the index to have the same number of dimensions as values
+        ).view(indices.size())
+    else:
+        # this means we have a vectorized version,
+        # so we have to adjust the index
+        indices = indices.unsqueeze(-1).expand(values.shape)
+        return torch.gather(values, index.batch_dims, indices)
+
+
+def flatten(index, name="segmented_flatten"):
+    """
+    Flattens a batched index map (which is typically of shape batch_size, seq_length) to a 1d index map. This operation
+    relabels the segments to keep batch elements distinct. The k-th batch element will have indices shifted by
+    *num_segments* * (k - 1). The result is a tensor with *num_segments* multiplied by the number of elements in the
+    batch.
+
+    Args:
+        index (`IndexMap`):
+            IndexMap to flatten.
+        name (`str`, *optional*, defaults to 'segmented_flatten'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        (`IndexMap`): The flattened IndexMap.
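+
+    Example (an added illustrative sketch; the toy tensors below are made up):
+
+    ```python
+    >>> import torch
+
+    >>> # two batch elements, three positions each, with a segment id per position
+    >>> index = IndexMap(indices=torch.tensor([[0, 0, 1], [0, 2, 2]]), num_segments=torch.tensor(3), batch_dims=1)
+    >>> flat = flatten(index)
+    >>> flat.indices  # the second batch element is shifted by num_segments = 3
+    tensor([0, 0, 1, 3, 5, 5])
+    >>> int(flat.num_segments)
+    6
+    ```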
+    """
+    # first, get batch_size as a scalar tensor
+    batch_size = torch.prod(torch.tensor(list(index.batch_shape())))
+    # next, create offset as 1-D tensor of length batch_size,
+    # and multiply element-wise by num segments (to offset different elements in the batch) e.g. if batch size is 2: [0, 64]
+    offset = torch.arange(start=0, end=batch_size, device=index.num_segments.device) * index.num_segments
+    offset = offset.view(index.batch_shape())
+    for _ in range(index.batch_dims, len(index.indices.size())):  # typically range(1,2)
+        offset = offset.unsqueeze(-1)
+
+    indices = offset + index.indices
+    return IndexMap(indices=indices.view(-1), num_segments=index.num_segments * batch_size, batch_dims=0)
+
+
+def range_index_map(batch_shape, num_segments, name="range_index_map"):
+    """
+    Constructs an index map equal to range(num_segments).
+
+    Args:
+        batch_shape (`torch.Size`):
+            Batch shape.
+        num_segments (`int`):
+            Number of segments.
+        name (`str`, *optional*, defaults to 'range_index_map'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
+    """
+    device = num_segments.device if torch.is_tensor(num_segments) else "cpu"
+    batch_shape = torch.as_tensor(
+        batch_shape, dtype=torch.long, device=device
+    )  # create a rank 1 tensor containing the batch_shape (e.g. [2])
+    assert len(batch_shape.size()) == 1
+    num_segments = torch.as_tensor(
+        num_segments, device=device
+    )  # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
+    assert len(num_segments.size()) == 0
+
+    indices = torch.arange(
+        start=0, end=num_segments, device=num_segments.device
+    )  # create a rank 1 vector with num_segments elements
+    new_tensor = torch.cat(
+        [torch.ones_like(batch_shape, dtype=torch.long, device=num_segments.device), num_segments.unsqueeze(dim=0)],
+        dim=0,
+    )
+    # new_tensor is just a vector of [1 64] for example (assuming only 1 batch dimension)
+    new_shape = [int(x) for x in new_tensor.tolist()]
+    indices = indices.view(new_shape)
+
+    multiples = torch.cat([batch_shape, torch.as_tensor([1], device=device)], dim=0)
+    indices = indices.repeat(multiples.tolist())
+    # equivalent in NumPy:
+    # indices = torch.as_tensor(np.tile(indices.numpy(), multiples.tolist()))
+
+    return IndexMap(indices=indices, num_segments=num_segments, batch_dims=list(batch_shape.size())[0])
+
+
+def _segment_reduce(values, index, segment_reduce_fn, name):
+    """
+    Applies a segment reduction segment-wise.
+
+    Args:
+        values (`torch.Tensor`):
+            Tensor with segment values.
+        index (`IndexMap`):
+            IndexMap.
+        segment_reduce_fn (`str`):
+            Name for the reduce operation. One of "sum", "mean", "amax" or "amin".
+        name (`str`):
+            Name for the operation. Currently not used.
+
+    Returns:
+        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
+        reduced values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    # Flatten the batch dimensions, as segment ops (scatter) do not support batching.
+    # However, if `values` has extra dimensions to the right, keep them
+    # unflattened. Segmented ops support vector-valued operations.
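+    # Illustrative walk-through (added comment, with toy numbers): with batch_shape (2,) and
+    # num_segments 3, per-batch indices [[0, 1, 1], [2, 2, 0]] flatten to [0, 1, 1, 5, 5, 3],
+    # so a single scatter_reduce over 2 * 3 = 6 output slots reduces both batch elements at once.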
+    flat_index = flatten(index)
+    vector_shape = values.size()[len(index.indices.size()) :]  # torch.Size object
+    flattened_shape = torch.cat(
+        [torch.as_tensor([-1], dtype=torch.long), torch.as_tensor(vector_shape, dtype=torch.long)], dim=0
+    )
+    # use "reshape" rather than "view", as `values` is not necessarily contiguous at this point
+    flat_values = values.reshape(flattened_shape.tolist())
+
+    out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
+    segment_means = out.scatter_reduce(
+        dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
+    )
+
+    device = index.num_segments.device
+    # Unflatten the values.
+    new_shape = torch.cat(
+        [
+            torch.as_tensor(index.batch_shape(), dtype=torch.long, device=device),
+            torch.as_tensor([index.num_segments], dtype=torch.long, device=device),
+            torch.as_tensor(vector_shape, dtype=torch.long, device=device),
+        ],
+        dim=0,
+    )
+
+    output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
+    output_index = range_index_map(index.batch_shape(), index.num_segments)
+    return output_values, output_index
+
+
+def reduce_sum(values, index, name="segmented_reduce_sum"):
+    """
+    Sums a tensor over its segments.
+
+    Outputs 0 for empty segments.
+
+    This operation computes the sum over segments, with support for:
+
+      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+      - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a sum of
+        vectors rather than scalars. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
+            Tensor containing the values of which the sum must be taken segment-wise.
+        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
+            Index defining the segments.
+        name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
+        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    return _segment_reduce(values, index, "sum", name)
+
+
+def reduce_mean(values, index, name="segmented_reduce_mean"):
+    """
+    Averages a tensor over its segments.
+
+    Outputs 0 for empty segments.
+
+    This operation computes the mean over segments, with support for:
+
+      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+      - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a mean of
+        vectors rather than scalars.
+
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
+            Tensor containing the values of which the mean must be taken segment-wise.
+        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
+            Index defining the segments.
+        name (`str`, *optional*, defaults to 'segmented_reduce_mean'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
+        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
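+
+    Example (an added illustrative sketch with toy values):
+
+    ```python
+    >>> import torch
+
+    >>> values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+    >>> index = IndexMap(indices=torch.tensor([[0, 0, 1, 1]]), num_segments=2, batch_dims=1)
+    >>> output_values, _ = reduce_mean(values, index)
+    >>> output_values  # mean of [1., 2.] and of [3., 4.]
+    tensor([[1.5000, 3.5000]])
+    ```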
+    """
+    return _segment_reduce(values, index, "mean", name)
+
+
+def reduce_max(values, index, name="segmented_reduce_max"):
+    """
+    Computes the maximum over segments.
+
+    This operation computes the maximum over segments, with support for:
+
+      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+      - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
+        maximum of vectors rather than scalars.
+
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
+            Tensor containing the values of which the max must be taken segment-wise.
+        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
+            Index defining the segments.
+        name (`str`, *optional*, defaults to 'segmented_reduce_max'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
+        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    return _segment_reduce(values, index, "amax", name)
+
+
+def reduce_min(values, index, name="segmented_reduce_min"):
+    """
+    Computes the minimum over segments.
+
+    This operation computes the minimum over segments, with support for:
+
+      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+      - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
+        minimum of vectors rather than scalars.
+
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
+            Tensor containing the values of which the min must be taken segment-wise.
+        index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
+            Index defining the segments.
+        name (`str`, *optional*, defaults to 'segmented_reduce_min'):
+            Name for the operation. Currently not used.
+
+    Returns:
+        output_values (`torch.Tensor` of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
+        output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    return _segment_reduce(values, index, "amin", name)
+
+
+# End of everything related to segmented tensors
+
+
+def compute_column_logits(
+    sequence_output, column_output_weights, column_output_bias, cell_index, cell_mask, allow_empty_column_selection
+):
+    """
+    Computes the column logits.
+
+    Args:
+        sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
+        column_output_weights (`torch.FloatTensor` of shape `(hidden_size)`):
+            Weights of the linear layer for column selection.
+        column_output_bias (`torch.FloatTensor` of shape `()`):
+            Bias of the linear layer for column selection.
+        cell_index (`ProductIndexMap`):
+            Index that groups tokens into cells.
+        cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
+            Mask for cells that exist in the table (i.e. that are not padding).
+        allow_empty_column_selection (`bool`):
+            Whether to allow the model to select no column at all.
+
+    Returns:
+        column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`): Tensor containing the column logits
+        for every example in the batch.
+    """
+
+    # First, compute the token logits (batch_size, seq_len) - without temperature
+    token_logits = torch.einsum("bsj,j->bs", sequence_output, column_output_weights) + column_output_bias
+
+    # Next, average the logits per cell (batch_size, max_num_cols*max_num_rows)
+    cell_logits, cell_logits_index = reduce_mean(token_logits, cell_index)
+
+    # Finally, average the logits per column (batch_size, max_num_cols)
+    column_index = cell_index.project_inner(cell_logits_index)
+    column_logits, out_index = reduce_sum(cell_logits * cell_mask, column_index)
+
+    cell_count, _ = reduce_sum(cell_mask, column_index)
+    column_logits /= cell_count + EPSILON_ZERO_DIVISION
+
+    # Mask columns that do not appear in the example.
+    is_padding = torch.logical_and(cell_count < 0.5, ~torch.eq(out_index.indices, 0))
+    column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * torch.as_tensor(
+        is_padding, dtype=torch.float32, device=is_padding.device
+    )
+
+    if not allow_empty_column_selection:
+        column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * torch.as_tensor(
+            torch.eq(out_index.indices, 0), dtype=torch.float32, device=out_index.indices.device
+        )
+
+    return column_logits
+
+
+def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
+    """
+    Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The
+    model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside
+    the selected column are never selected.
+
+    Args:
+        token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Tensor containing the logits per token.
+        column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`):
+            Tensor containing the logits per column.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Labels per token.
+        cell_index (`ProductIndexMap`):
+            Index that groups tokens into cells.
+        col_index (`IndexMap`):
+            Index that groups tokens into columns.
+        cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
+            Mask for cells that exist in the table (i.e. that are not padding).
+
+    Returns:
+        selection_loss_per_example (`torch.FloatTensor` of shape `(batch_size,)`): Loss for each example. logits
+        (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): New logits which are only allowed to select
+        cells in a single column. Logits outside of the most likely column according to *column_logits* will be set to
+        a very low value (such that the probabilities are 0).
+    """
+    # Part 1: column loss
+
+    # First find the column we should select. We use the column with the maximum number of selected cells.
+    labels_per_column, _ = reduce_sum(torch.as_tensor(labels, dtype=torch.float32, device=labels.device), col_index)
+    # shape of labels_per_column is (batch_size, max_num_cols). It contains the number of label ids for every column, for every example
+    column_label = torch.argmax(labels_per_column, dim=-1)  # shape (batch_size,)
+    # Check if there are no selected cells in the column. In that case the model
+    # should predict the special column id 0, which means "select nothing".
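+    # (added note, with toy numbers) e.g. with labels_per_column = [[0., 2., 0.]] the column label is 1,
+    # the column holding the most selected cells; with [[0., 0., 0.]] the maximum is 0, so the example
+    # falls back to the special "select nothing" column 0 below.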
+ no_cell_selected = torch.eq( + torch.max(labels_per_column, dim=-1)[0], 0 + ) # no_cell_selected is of shape (batch_size,) and equals True + # if an example of the batch has no cells selected (i.e. if there are no labels set to 1 for that example) + column_label = torch.where( + no_cell_selected.view(column_label.size()), torch.zeros_like(column_label), column_label + ) + + column_dist = torch.distributions.Categorical(logits=column_logits) # shape (batch_size, max_num_cols) + column_loss_per_example = -column_dist.log_prob(column_label) + + # Part 2: cell loss + + # Reduce the labels and logits to per-cell from per-token. + # logits_per_cell: shape (batch_size, max_num_rows*max_num_cols) i.e. (batch_size, 64*32) + logits_per_cell, _ = reduce_mean(token_logits, cell_index) + # labels_per_cell: shape (batch_size, 64*32), indicating whether each cell should be selected (1) or not (0) + labels_per_cell, labels_index = reduce_max( + torch.as_tensor(labels, dtype=torch.long, device=labels.device), cell_index + ) + + # Mask for the selected column. + # column_id_for_cells: shape (batch_size, 64*32), indicating to which column each cell belongs + column_id_for_cells = cell_index.project_inner(labels_index).indices + # column_mask: shape (batch_size, 64*32), equal to 1 if cell belongs to column to be selected + column_mask = torch.as_tensor( + torch.eq(column_id_for_cells, torch.unsqueeze(column_label, dim=-1)), + dtype=torch.float32, + device=cell_mask.device, + ) + + # Compute the log-likelihood for cells, but only for the selected column. + cell_dist = torch.distributions.Bernoulli(logits=logits_per_cell) # shape (batch_size, 64*32) + cell_log_prob = cell_dist.log_prob(labels_per_cell.type(torch.float32)) # shape(batch_size, 64*32) + + cell_loss = -torch.sum(cell_log_prob * column_mask * cell_mask, dim=1) + + # We need to normalize the loss by the number of cells in the column. + cell_loss /= torch.sum(column_mask * cell_mask, dim=1) + EPSILON_ZERO_DIVISION + + selection_loss_per_example = column_loss_per_example + selection_loss_per_example += torch.where( + no_cell_selected.view(selection_loss_per_example.size()), + torch.zeros_like(selection_loss_per_example), + cell_loss, + ) + + # Set the probs outside the selected column (selected by the *model*) + # to 0. This ensures backwards compatibility with models that select + # cells from multiple columns. + selected_column_id = torch.as_tensor( + torch.argmax(column_logits, dim=-1), dtype=torch.long, device=column_logits.device + ) # shape (batch_size,) + + # selected_column_mask: shape (batch_size, 64*32), equal to 1 if cell belongs to column selected by the model + selected_column_mask = torch.as_tensor( + torch.eq(column_id_for_cells, torch.unsqueeze(selected_column_id, dim=-1)), + dtype=torch.float32, + device=selected_column_id.device, + ) + + # Never select cells with the special column id 0. + selected_column_mask = torch.where( + torch.eq(column_id_for_cells, 0).view(selected_column_mask.size()), + torch.zeros_like(selected_column_mask), + selected_column_mask, + ) + new_logits_per_cell = logits_per_cell + CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * selected_column_mask) + logits = gather(new_logits_per_cell, cell_index) + + return selection_loss_per_example, logits + + +def compute_token_logits(sequence_output, temperature, output_weights, output_bias): + """ + Computes logits per token + + Args: + sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Also known as last_hidden_state. 
Sequence of hidden-states at the output of the last layer of the model.
+        temperature (`float`):
+            Temperature for the Bernoulli distribution.
+        output_weights (`torch.FloatTensor` of shape `(hidden_size,)`):
+            Weights of the linear layer for cell selection.
+        output_bias (`torch.FloatTensor` of shape `()`):
+            Bias of the linear layer for cell selection.
+
+    Returns:
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Logits per token.
+    """
+    logits = (torch.einsum("bsj,j->bs", sequence_output, output_weights) + output_bias) / temperature
+
+    return logits
+
+
+def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
+    """
+    Finds examples where the model should select cells with no aggregation.
+
+    Returns a mask that determines for which examples the model should select answers directly from the table, without
+    any aggregation function. If the answer is a piece of text the case is unambiguous, as aggregation functions only
+    apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation
+    function. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the
+    aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold
+    for this is the hyperparameter *cell_selection_preference*.
+
+    Args:
+        answer (`torch.FloatTensor` of shape `(batch_size, )`):
+            Answer for every example in the batch. NaN if there is no scalar answer.
+        pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+            Output of the pooler (BertPooler) on top of the encoder layer.
+        cell_selection_preference (`float`):
+            Preference for cell selection in ambiguous cases.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Labels per token.
+        aggregation_classifier (`torch.nn.Linear`):
+            Aggregation head.
+
+    Returns:
+        aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use
+        aggregation functions.
+    """
+    # torch.FloatTensor(batch_size,)
+    aggregate_mask_init = torch.logical_not(torch.isnan(answer)).type(torch.FloatTensor).to(answer.device)
+    logits_aggregation = aggregation_classifier(pooled_output)
+    dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
+    # Index 0 corresponds to "no aggregation".
+    aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
+
+    # Cell selection examples according to current model.
+    is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference
+
+    # Examples with non-empty cell selection supervision.
+    is_cell_supervision_available = torch.sum(labels, dim=1) > 0
+
+    # torch.where is not equivalent to tf.where (in tensorflow 1),
+    # hence the added .view on the condition to match the shape of the first tensor
+    aggregate_mask = torch.where(
+        torch.logical_and(is_pred_cell_selection, is_cell_supervision_available).view(aggregate_mask_init.size()),
+        torch.zeros_like(aggregate_mask_init, dtype=torch.float32),
+        aggregate_mask_init,
+    )
+
+    aggregate_mask = aggregate_mask.detach()
+
+    return aggregate_mask
+
+
+def _calculate_aggregation_loss_known(
+    logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
+):
+    """
+    Calculates aggregation loss when its type is known during training.
+ + In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation" + should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting + where aggregation type is always known, standard cross entropy loss is accumulated for all examples + + Args: + logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`): + Logits per aggregation operation. + aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`): + A mask set to 1 for examples that should use aggregation functions. + aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`): + Aggregation function id for every example in the batch. + use_answer_as_supervision (`bool`, *optional*): + Whether to use the answer as the only supervision for aggregation examples. + num_aggregation_labels (`int`, *optional*, defaults to 0): + The number of aggregation operators to predict. + + Returns: + aggregation_loss_known (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (when its type is known + during training) per example. + """ + if use_answer_as_supervision: + # Prepare "no aggregation" targets for cell selection examples. + target_aggregation = torch.zeros_like(aggregate_mask, dtype=torch.long) + else: + # Use aggregation supervision as the target. + target_aggregation = aggregation_labels + + one_hot_labels = nn.functional.one_hot(target_aggregation, num_classes=num_aggregation_labels).type(torch.float32) + log_probs = nn.functional.log_softmax(logits_aggregation, dim=-1) + + # torch.FloatTensor[batch_size] + per_example_aggregation_intermediate = -torch.sum(one_hot_labels * log_probs, dim=-1) + if use_answer_as_supervision: + # Accumulate loss only for examples requiring cell selection + # (no aggregation). + return per_example_aggregation_intermediate * (1 - aggregate_mask) + else: + return per_example_aggregation_intermediate + + +def _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask): + """ + Calculates aggregation loss in the case of answer supervision. + + Args: + logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`): + Logits per aggregation operation. + aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`): + A mask set to 1 for examples that should use aggregation functions + + Returns: + aggregation_loss_unknown (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (in case of answer + supervision) per example. + """ + dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation) + # Index 0 corresponds to "no aggregation". + aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1) + # Predict some aggregation in case of an answer that needs aggregation. + # This increases the probability of all aggregation functions, in a way + # similar to MML, but without considering whether the function gives the + # correct answer. + return -torch.log(aggregation_ops_total_mass) * aggregate_mask + + +def _calculate_aggregation_loss( + logits_aggregation, + aggregate_mask, + aggregation_labels, + use_answer_as_supervision, + num_aggregation_labels, + aggregation_loss_weight, +): + """ + Calculates the aggregation loss per example. + + Args: + logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`): + Logits per aggregation operation. 
+        aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
+            A mask set to 1 for examples that should use aggregation functions.
+        aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
+            Aggregation function id for every example in the batch.
+        use_answer_as_supervision (`bool`, *optional*):
+            Whether to use the answer as the only supervision for aggregation examples.
+        num_aggregation_labels (`int`, *optional*, defaults to 0):
+            The number of aggregation operators to predict.
+        aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
+            Importance weight for the aggregation loss.
+
+    Returns:
+        aggregation_loss (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss per example.
+    """
+    per_example_aggregation_loss = _calculate_aggregation_loss_known(
+        logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
+    )
+
+    if use_answer_as_supervision:
+        # Add aggregation loss for numeric answers that need aggregation.
+        per_example_aggregation_loss += _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask)
+    return aggregation_loss_weight * per_example_aggregation_loss
+
+
+def _calculate_expected_result(
+    dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
+):
+    """
+    Calculates the expected result given cell and aggregation probabilities.
+
+    Args:
+        dist_per_cell (`torch.distributions.Bernoulli`):
+            Cell selection distribution for each cell.
+        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
+            Numeric values of every token. NaN for tokens which are not numeric values.
+        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
+            Scale of the numeric values of every token.
+        input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
+            Mask for the table, without question tokens and table headers.
+        logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
+            Logits per aggregation operation.
+        config ([`TapasConfig`]):
+            Model configuration class with all the hyperparameters of the model.
+
+    Returns:
+        expected_result (`torch.FloatTensor` of shape `(batch_size,)`): The expected result per example.
+    """
+    if config.use_gumbel_for_cells:
+        gumbel_dist = torch.distributions.RelaxedBernoulli(
+            # The token logits were already divided by the temperature and used for
+            # computing cell selection errors, so we need to multiply them again here
+            temperature=config.temperature,
+            logits=dist_per_cell.logits * config.temperature,
+        )
+        scaled_probability_per_cell = gumbel_dist.sample()
+    else:
+        scaled_probability_per_cell = dist_per_cell.probs
+
+    # [batch_size, seq_length]
+    scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float
+    count_result = torch.sum(scaled_probability_per_cell, dim=1)
+    numeric_values_masked = torch.where(
+        torch.isnan(numeric_values), torch.zeros_like(numeric_values), numeric_values
+    )  # Mask non-numeric table values to zero.
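+    # (added note) The three expected aggregates computed below are differentiable surrogates of
+    # COUNT, SUM and AVERAGE: COUNT is the total selection probability mass, SUM weights every
+    # cell value by its selection probability, and AVERAGE has no closed-form expectation, hence
+    # the ratio / first-order / second-order approximations selected by the config.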
+    sum_result = torch.sum(scaled_probability_per_cell * numeric_values_masked, dim=1)
+    avg_approximation = config.average_approximation_function
+    if avg_approximation == AverageApproximationFunction.RATIO:
+        average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION)
+    elif avg_approximation == AverageApproximationFunction.FIRST_ORDER:
+        # For every cell: one plus the sum of the selection probabilities of all the other cells.
+        # Ex here stands for expectation, more explicitly the expectation of the sum of N-1 Bernoulli random variables plus
+        # the constant 1, which is computed by adding all N expected values and subtracting the extra one. It corresponds to X_c
+        # in Appendix D of the original TAPAS paper, which is trying to approximate the average of a random set.
+        ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
+        average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell / ex, dim=1)
+    elif avg_approximation == AverageApproximationFunction.SECOND_ORDER:
+        # For every cell: one plus the sum of the selection probabilities of all the other cells.
+        ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
+        pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell)
+        var = torch.sum(pointwise_var, dim=1, keepdim=True) - pointwise_var
+
+        multiplier = (var / torch.square(ex) + 1) / ex
+        average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell * multiplier, dim=1)
+    else:
+        raise ValueError(f"Invalid average_approximation_function: {config.average_approximation_function}")
+
+    if config.use_gumbel_for_aggregation:
+        gumbel_dist = torch.distributions.RelaxedOneHotCategorical(
+            config.aggregation_temperature, logits=logits_aggregation[:, 1:]
+        )
+        # [batch_size, num_aggregation_labels - 1]
+        aggregation_op_only_probs = gumbel_dist.sample()
+    else:
+        # [batch_size, num_aggregation_labels - 1]
+        aggregation_op_only_probs = nn.functional.softmax(
+            logits_aggregation[:, 1:] / config.aggregation_temperature, dim=-1
+        )
+
+    all_results = torch.cat(
+        [
+            torch.unsqueeze(sum_result, dim=1),
+            torch.unsqueeze(average_result, dim=1),
+            torch.unsqueeze(count_result, dim=1),
+        ],
+        dim=1,
+    )
+
+    expected_result = torch.sum(all_results * aggregation_op_only_probs, dim=1)
+    return expected_result
+
+
+# PyTorch does not currently support Huber loss with a custom delta, so we define it ourselves
+def huber_loss(input, target, delta: float = 1.0):
+    errors = torch.abs(input - target)  # shape (batch_size,)
+    return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2))
+
+
+def _calculate_regression_loss(
+    answer,
+    aggregate_mask,
+    dist_per_cell,
+    numeric_values,
+    numeric_values_scale,
+    input_mask_float,
+    logits_aggregation,
+    config,
+):
+    """
+    Calculates the regression loss per example.
+
+    Args:
+        answer (`torch.FloatTensor` of shape `(batch_size,)`):
+            Answer for every example in the batch. NaN if there is no scalar answer.
+        aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`):
+            A mask set to 1 for examples that should use aggregation functions.
+        dist_per_cell (`torch.distributions.Bernoulli`):
+            Cell selection distribution for each cell.
+        numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
+            Numeric values of every token. NaN for tokens which are not numeric values.
+        numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
+            Scale of the numeric values of every token.
+ input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`): + Mask for the table, without question tokens and table headers. + logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`): + Logits per aggregation operation. + config ([`TapasConfig`]): + Model configuration class with all the parameters of the model + + Returns: + per_example_answer_loss_scaled (`torch.FloatTensor` of shape `(batch_size,)`): Scales answer loss for each + example in the batch. large_answer_loss_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask which is 1 + for examples for which their answer loss is larger than the answer_loss_cutoff. + """ + # float32 (batch_size,) + expected_result = _calculate_expected_result( + dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config + ) + + # float32 (batch_size,) + answer_masked = torch.where(torch.isnan(answer), torch.zeros_like(answer), answer) + + if config.use_normalized_answer_loss: + normalizer = (torch.max(torch.abs(expected_result), torch.abs(answer_masked)) + EPSILON_ZERO_DIVISION).detach() + + normalized_answer_masked = answer_masked / normalizer + normalized_expected_result = expected_result / normalizer + per_example_answer_loss = huber_loss( + normalized_expected_result * aggregate_mask, normalized_answer_masked * aggregate_mask + ) + else: + per_example_answer_loss = huber_loss( + expected_result * aggregate_mask, answer_masked * aggregate_mask, delta=config.huber_loss_delta + ) + + if config.answer_loss_cutoff is None: + large_answer_loss_mask = torch.ones_like(per_example_answer_loss, dtype=torch.float32) + + else: + large_answer_loss_mask = torch.where( + per_example_answer_loss > config.answer_loss_cutoff, + torch.zeros_like(per_example_answer_loss, dtype=torch.float32), + torch.ones_like(per_example_answer_loss, dtype=torch.float32), + ) + per_example_answer_loss_scaled = config.answer_loss_importance * (per_example_answer_loss * aggregate_mask) + + return per_example_answer_loss_scaled, large_answer_loss_mask + + +__all__ = [ + "TapasForMaskedLM", + "TapasForQuestionAnswering", + "TapasForSequenceClassification", + "TapasModel", + "TapasPreTrainedModel", + "load_tf_weights_in_tapas", +] diff --git a/docs/transformers/src/transformers/models/tapas/modeling_tf_tapas.py b/docs/transformers/src/transformers/models/tapas/modeling_tf_tapas.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5abdd7fabcc6be96b7f8bc5f5f2f7b50459ea9 --- /dev/null +++ b/docs/transformers/src/transformers/models/tapas/modeling_tf_tapas.py @@ -0,0 +1,2462 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 TAPAS model.""" + +from __future__ import annotations + +import enum +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPooling, + TFMaskedLMOutput, + TFSequenceClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_tensorflow_probability_available, + logging, + replace_return_docstrings, +) +from .configuration_tapas import TapasConfig + + +logger = logging.get_logger(__name__) + +# soft dependency +if is_tensorflow_probability_available(): + try: + import tensorflow_probability as tfp + + # On the first call, check whether a compatible version of TensorFlow is installed + # TensorFlow Probability depends on a recent stable release of TensorFlow + n = tfp.distributions.Normal(loc=0.0, scale=1.0) + except ImportError: + logger.error( + "TAPAS models are not usable since `tensorflow_probability` can't be loaded. " + "It seems you have `tensorflow_probability` installed with the wrong tensorflow version. " + "Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability." + ) +else: + try: + import tensorflow_probability as tfp + + # On the first call, check whether a compatible version of TensorFlow is installed + # TensorFlow Probability depends on a recent stable release of TensorFlow + _ = tfp.distributions.Normal(loc=0.0, scale=1.0) + except ImportError: + pass + +_CONFIG_FOR_DOC = "TapasConfig" +_CHECKPOINT_FOR_DOC = "google/tapas-base" + + +EPSILON_ZERO_DIVISION = 1e-10 +CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0 + + +@dataclass +class TFTableQuestionAnsweringOutput(ModelOutput): + """ + Output type of [`TFTapasForQuestionAnswering`]. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)): + Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the + semi-supervised regression loss and (optionally) supervised loss for aggregations. + logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Prediction scores of the cell selection head, for every token. + logits_aggregation (`tf.Tensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`): + Prediction scores of the aggregation head, for every aggregation operator. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: tf.Tensor | None = None + logits: Optional[tf.Tensor] = None + logits_aggregation: tf.Tensor | None = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +class TFTapasEmbeddings(keras.layers.Layer): + """ + Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of + additional token type embeddings to encode tabular structure. + """ + + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.number_of_token_type_embeddings = len(config.type_vocab_sizes) + self.reset_position_index_per_cell = config.reset_position_index_per_cell + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + for i, type_vocab_size in enumerate(self.config.type_vocab_sizes): + with tf.name_scope(f"token_type_embeddings_{i}"): + setattr( + self, + f"token_type_embeddings_{i}", + self.add_weight( + name="embeddings", + shape=[type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def call( + self, + input_ids: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + token_type_ids: Optional[tf.Tensor] = None, + inputs_embeds: Optional[tf.Tensor] = None, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. 
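+
+        Note (added; an illustrative sketch with made-up numbers): when
+        `config.reset_position_index_per_cell` is `True`, position ids restart at 0 inside every
+        table cell, capped at `max_position_embeddings - 1`:
+
+        ```python
+        >>> import tensorflow as tf
+
+        >>> position = tf.constant([[0, 1, 2, 3, 4, 5]])
+        >>> first_position = tf.constant([[0, 0, 2, 2, 2, 5]])  # first absolute position of each token's cell
+        >>> tf.math.minimum(511, position - first_position).numpy()  # 511 = max_position_embeddings - 1
+        array([[0, 1, 0, 1, 2, 0]], dtype=int32)
+        ```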
+ """ + assert not (input_ids is None and inputs_embeds is None) + if input_ids is not None: + input_shape = shape_list(input_ids) + else: + input_shape = shape_list(inputs_embeds)[:-1] + + seq_length = input_shape[1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape + [self.number_of_token_type_embeddings], value=0) + + if position_ids is None: + # create absolute position embeddings + position_ids = tf.expand_dims(tf.range(start=0, limit=seq_length), axis=0) + position_ids = tf.broadcast_to(position_ids, shape=input_shape) + # when self.config.reset_position_index_per_cell is set to True, create relative position embeddings + if self.reset_position_index_per_cell: + # shape (batch_size, seq_len) + col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1) + # shape (batch_size, seq_len) + row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1) + # shape (batch_size, seq_len) + full_index = ProductIndexMap(col_index, row_index) + # shape (max_rows * max_columns,). First absolute position for every cell + first_position_per_segment = reduce_min(position_ids, full_index)[0] + # ? shape (batch_size, seq_len). First absolute position of the cell for every token + first_position = gather(first_position_per_segment, full_index) + # shape (1, seq_len) + position = tf.expand_dims(tf.range(start=0, limit=seq_length), axis=0) + position_ids = tf.math.minimum(self.max_position_embeddings - 1, position - first_position) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + position_embeddings = tf.gather(self.position_embeddings, indices=position_ids) + + final_embeddings = inputs_embeds + position_embeddings + + for i in range(self.number_of_token_type_embeddings): + name = f"token_type_embeddings_{i}" + final_embeddings += tf.gather(params=getattr(self, name), indices=token_type_ids[:, :, i]) + + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Tapas +class TFTapasSelfAttention(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: 
int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFTapasModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. 
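+        # (added note) stable_softmax is numerically the same as tf.nn.softmax; it adds a tiny
+        # constant to the logits as a workaround so the op also runs reliably under XLA on CPU.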
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Tapas +class TFTapasSelfOutput(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Tapas +class TFTapasAttention(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFTapasSelfAttention(config, name="self") + self.dense_output = TFTapasSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Tapas +class TFTapasIntermediate(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Tapas +class TFTapasOutput(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Tapas +class TFTapasLayer(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFTapasAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + 
self.crossattention = TFTapasAttention(config, name="crossattention") + self.intermediate = TFTapasIntermediate(config, name="intermediate") + self.bert_output = TFTapasOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: Tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Copied from 
transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Tapas +class TFTapasEncoder(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layer = [TFTapasLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: Tuple[Tuple[tf.Tensor]] | None, + use_cache: Optional[bool], + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Tapas +class TFTapasPooler(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
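+        # e.g. hidden_states of shape (batch_size, seq_len, hidden_size) becomes
+        # first_token_tensor of shape (batch_size, hidden_size) below.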
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->Tapas +class TFTapasPredictionHeadTransform(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = get_tf_activation(config.hidden_act) + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(inputs=hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->Tapas +class TFTapasLMPredictionHead(keras.layers.Layer): + def __init__(self, config: TapasConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + + self.transform = TFTapasPredictionHeadTransform(config, name="transform") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
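+        # Concretely, call() below multiplies the hidden states by the transposed embedding matrix of
+        # shape (vocab_size, hidden_size), so no separate output projection matrix is learned.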
+ self.input_embeddings = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "transform", None) is not None: + with tf.name_scope(self.transform.name): + self.transform.build(None) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self.input_embeddings + + def set_output_embeddings(self, value: tf.Variable): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self) -> Dict[str, tf.Variable]: + return {"bias": self.bias} + + def set_bias(self, value: tf.Variable): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.transform(hidden_states=hidden_states) + seq_length = shape_list(hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->Tapas +class TFTapasMLMHead(keras.layers.Layer): + def __init__(self, config: TapasConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.predictions = TFTapasLMPredictionHead(config, input_embeddings, name="predictions") + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + prediction_scores = self.predictions(hidden_states=sequence_output) + + return prediction_scores + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + + +@keras_serializable +class TFTapasMainLayer(keras.layers.Layer): + config_class = TapasConfig + + def __init__(self, config: TapasConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFTapasEmbeddings(config, name="embeddings") + self.encoder = TFTapasEncoder(config, name="encoder") + self.pooler = TFTapasPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(dims=input_shape, value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape + [len(self.config.type_vocab_sizes)], value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
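+        # For example (illustrative values): an attention_mask row [1, 1, 0] becomes
+        # (1 - [1, 1, 0]) * -10000 = [0, 0, -10000] below, broadcast over all heads and query positions.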
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            past_key_values=None,
+            use_cache=None,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (
+                sequence_output,
+                pooled_output,
+            ) + encoder_outputs[1:]
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "pooler", None) is not None:
+            with tf.name_scope(self.pooler.name):
+                self.pooler.build(None)
+
+
+class TFTapasPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = TapasConfig
+    base_model_prefix = "tapas"
+
+    @property
+    def input_signature(self):
+        return {
+            "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+            "attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
+            "token_type_ids": tf.TensorSpec((None, None, 7), tf.int32, name="token_type_ids"),
+        }
+
+
+TAPAS_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    Parameters:
+        config ([`TapasConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TAPAS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+            [`PreTrainedTokenizer.encode`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0}, 7)`, *optional*):
+            Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
+            class for more info.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. If
+            `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
+            used. Selected in the range `[0, config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode; in graph mode the value will always be set to `True`.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+    "The bare Tapas Model transformer outputting raw hidden-states without any specific head on top.",
+    TAPAS_START_DOCSTRING,
+)
+class TFTapasModel(TFTapasPreTrainedModel):
+    def __init__(self, config: TapasConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.tapas = TFTapasMainLayer(config, name="tapas")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, TapasModel
+        >>> import pandas as pd
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
+        >>> model = TapasModel.from_pretrained("google/tapas-base")
+
+        >>> data = {
+        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
+        ...     "Age": ["56", "45", "59"],
+        ...     "Number of movies": ["87", "53", "69"],
+        ...
} + >>> table = pd.DataFrame.from_dict(data) + >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"] + + >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + outputs = self.tapas( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "tapas", None) is not None: + with tf.name_scope(self.tapas.name): + self.tapas.build(None) + + +@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING) +class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss): + def __init__(self, config: TapasConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if config.is_decoder: + logger.warning( + "If you want to use `TFTapasForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.tapas = TFTapasMainLayer(config, add_pooling_layer=False, name="tapas") + self.lm_head = TFTapasMLMHead(config, input_embeddings=self.tapas.embeddings, name="cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.lm_head.predictions + + @unpack_inputs + @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasForMaskedLM + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base") + >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... "Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + + >>> inputs = tokenizer( + ... table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="tf" + ... ) + >>> labels = tokenizer( + ... table=table, queries="How many movies has George Clooney played in?", return_tensors="tf" + ... 
)["input_ids"] + + >>> outputs = model(**inputs, labels=labels) + >>> logits = outputs.logits + ```""" + outputs = self.tapas( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "tapas", None) is not None: + with tf.name_scope(self.tapas.name): + self.tapas.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +class TFTapasComputeTokenLogits(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + self.temperature = config.temperature + # cell selection heads + with tf.name_scope("output"): + self.output_weights = self.add_weight( + name="output_weights", + shape=(config.hidden_size,), + dtype=tf.float32, + trainable=True, + initializer=tf.zeros_initializer() + if config.init_cell_selection_weights_to_zero + else keras.initializers.TruncatedNormal(stddev=config.initializer_range), + ) + self.output_bias = self.add_weight( + name="output_bias", shape=(), trainable=True, initializer=tf.zeros_initializer() + ) + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + """ + Computes logits per token + + Args: + sequence_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the + model. + + Returns: + logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): Logits per token. + """ + logits = (tf.einsum("bsj,j->bs", sequence_output, self.output_weights) + self.output_bias) / self.temperature + return logits + + +class TFTapasComputeColumnLogits(keras.layers.Layer): + def __init__(self, config: TapasConfig, **kwargs): + super().__init__(**kwargs) + + with tf.name_scope("column_output"): + self.column_output_weights = self.add_weight( + name="column_output_weights", + shape=[config.hidden_size], + dtype=tf.float32, + trainable=True, + initializer=tf.zeros_initializer() + if config.init_cell_selection_weights_to_zero + else keras.initializers.TruncatedNormal(stddev=config.initializer_range), + ) + self.column_output_bias = self.add_weight( + name="column_output_bias", shape=(), trainable=True, initializer=tf.zeros_initializer() + ) + + def call(self, sequence_output, cell_index, cell_mask, allow_empty_column_selection) -> tf.Tensor: + """ + Computes the column logits. + + Args: + sequence_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the + model. + cell_index (`ProductIndexMap`): + Index that groups tokens into cells. 
+            cell_mask (`tf.Tensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
+                Mask for cells that exist in the table (i.e. that are not padding).
+            allow_empty_column_selection (`bool`):
+                Whether to allow the model to select no column at all.
+
+        Returns:
+            column_logits (`tf.Tensor` of shape `(batch_size, max_num_cols)`): Tensor containing the column logits for
+            every example in the batch.
+        """
+
+        # First, compute the token logits (batch_size, seq_len) - without temperature
+        token_logits = tf.einsum("bsj,j->bs", sequence_output, self.column_output_weights) + self.column_output_bias
+
+        # Next, average the logits per cell (batch_size, max_num_cols*max_num_rows)
+        cell_logits, cell_logits_index = reduce_mean(token_logits, cell_index)
+
+        # Finally, average the logits per column (batch_size, max_num_cols)
+        column_index = cell_index.project_inner(cell_logits_index)
+        column_logits, out_index = reduce_sum(cell_logits * cell_mask, column_index)
+
+        cell_count, _ = reduce_sum(cell_mask, column_index)
+        column_logits /= cell_count + EPSILON_ZERO_DIVISION
+
+        # Mask columns that do not appear in the example.
+        is_padding = tf.logical_and(cell_count < 0.5, tf.not_equal(out_index.indices, 0))
+        column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(is_padding, tf.float32)
+
+        if not allow_empty_column_selection:
+            column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(tf.equal(out_index.indices, 0), tf.float32)
+
+        return column_logits
+
+
+@add_start_docstrings(
+    """
+    Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables
+    (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for
+    SQA, WTQ or WikiSQL-supervised tasks.
+    """,
+    TAPAS_START_DOCSTRING,
+)
+class TFTapasForQuestionAnswering(TFTapasPreTrainedModel):
+    def __init__(self, config: TapasConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        # base model
+        self.tapas = TFTapasMainLayer(config, name="tapas")
+
+        # dropout
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+
+        self.compute_token_logits = TFTapasComputeTokenLogits(config, name="compute_token_logits")
+
+        self.compute_column_logits = TFTapasComputeColumnLogits(config, name="compute_column_logits")
+
+        if config.num_aggregation_labels > 0:
+            self.aggregation_classifier = keras.layers.Dense(
+                config.num_aggregation_labels,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="aggregation_classifier",
+            )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=TFTableQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        table_mask: np.ndarray | tf.Tensor | None = None,
+        aggregation_labels: np.ndarray | tf.Tensor | None = None,
+        float_answer: np.ndarray | tf.Tensor | None = None,
+        numeric_values: np.ndarray | tf.Tensor | None = None,
+        numeric_values_scale: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor
| None = None, + training: Optional[bool] = False, + ) -> Union[TFTableQuestionAnsweringOutput, Tuple[tf.Tensor]]: + r""" + table_mask (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): + Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and + padding are 0. + labels (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): + Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the + answer appearing in the table. Can be obtained using [`AutoTokenizer`]. + + - 1 for tokens that are **part of the answer**, + - 0 for tokens that are **not part of the answer**. + + aggregation_labels (`tf.Tensor` of shape `(batch_size, )`, *optional*): + Aggregation function index for every example in the batch for computing the aggregation loss. Indices + should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for + aggregation (WikiSQL-supervised). + float_answer (`tf.Tensor` of shape `(batch_size, )`, *optional*): + Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only + required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss. + numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): + Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using + [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the + regression loss. + numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`, *optional*): + Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case + of weak supervision for aggregation (WTQ) to calculate the regression loss. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasForQuestionAnswering + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq") + >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... "Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"] + + >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> logits_aggregation = outputs.logits_aggregation + ```""" + + outputs = self.tapas( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + pooled_output = outputs[1] + + sequence_output = self.dropout(sequence_output) + + if input_ids is not None: + input_shape = shape_list(input_ids) + else: + input_shape = shape_list(inputs_embeds)[:-1] + + # Construct indices for the table. 
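+        # token_type_ids carries one id per token for each of the 7 table-structure channels listed in
+        # `token_types` below; the row and column channels are pulled out to build the cell indices.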
+ if token_type_ids is None: + token_type_ids = tf.fill(input_shape + [len(self.config.type_vocab_sizes)], 0) + + token_types = [ + "segment_ids", + "column_ids", + "row_ids", + "prev_labels", + "column_ranks", + "inv_column_ranks", + "numeric_relations", + ] + + row_ids = token_type_ids[:, :, token_types.index("row_ids")] + column_ids = token_type_ids[:, :, token_types.index("column_ids")] + + # Construct indices for the table. + row_index = IndexMap( + indices=tf.minimum(tf.cast(row_ids, tf.int32), self.config.max_num_rows - 1), + num_segments=self.config.max_num_rows, + batch_dims=1, + ) + col_index = IndexMap( + indices=tf.minimum(tf.cast(column_ids, tf.int32), self.config.max_num_columns - 1), + num_segments=self.config.max_num_columns, + batch_dims=1, + ) + cell_index = ProductIndexMap(row_index, col_index) + + # Masks. + input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:-1] + if attention_mask is None: + attention_mask = tf.ones(input_shape) + # Table cells only, without question tokens and table headers. + if table_mask is None: + table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids), tf.zeros_like(row_ids)) + # [batch_size, seq_length] + input_mask_float = tf.cast(attention_mask, tf.float32) + table_mask_float = tf.cast(table_mask, tf.float32) + + # Mask for cells that exist in the table (i.e. that are not padding). + cell_mask, _ = reduce_mean(input_mask_float, cell_index) + + # Compute logits per token. These are used to select individual cells. + logits = self.compute_token_logits(sequence_output) + + # Compute logits per column. These are used to select a column. + column_logits = None + if self.config.select_one_column: + column_logits = self.compute_column_logits( + sequence_output, cell_index, cell_mask, self.config.allow_empty_column_selection + ) + + # Aggregate logits. + logits_aggregation = None + if self.config.num_aggregation_labels > 0: + logits_aggregation = self.aggregation_classifier(pooled_output) + + # Total loss calculation + total_loss = tf.zeros(shape=(1,), dtype=tf.float32) + calculate_loss = False + if labels is not None: + calculate_loss = True + is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision + + # Semi-supervised cell selection in case of no aggregation: + # If the answer (the denotation) appears directly in the table we might + # select the answer without applying any aggregation function. There are + # some ambiguous cases, see utils._calculate_aggregate_mask for more info. + # `aggregate_mask` is 1 for examples where we chose to aggregate and 0 + # for examples where we chose to select the answer directly. + # `labels` encodes the positions of the answer appearing in the table. 
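+            # For example (illustrative): for the question "How old is Brad Pitt?" with "56" present in a
+            # table cell, the answer can be selected directly, so that example's aggregate_mask entry is 0.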
+ if is_supervised: + aggregate_mask = None + else: + if float_answer is not None: + assert shape_list(labels)[0] == shape_list(float_answer)[0], ( + "Make sure the answers are a FloatTensor of shape (batch_size,)" + ) + # [batch_size] + aggregate_mask = _calculate_aggregate_mask( + float_answer, + pooled_output, + self.config.cell_selection_preference, + labels, + self.aggregation_classifier, + ) + else: + aggregate_mask = None + raise ValueError("You have to specify float answers in order to calculate the aggregate mask") + + # Cell selection log-likelihood + if self.config.average_logits_per_cell: + logits_per_cell, _ = reduce_mean(logits, cell_index) + logits = gather(logits_per_cell, cell_index) + dist_per_token = tfp.distributions.Bernoulli(logits=logits) + + # Compute cell selection loss per example. + selection_loss_per_example = None + if not self.config.select_one_column: + weight = tf.where( + labels == 0, + tf.ones_like(labels, dtype=tf.float32), + self.config.positive_label_weight * tf.ones_like(labels, dtype=tf.float32), + ) + selection_loss_per_token = -dist_per_token.log_prob(labels) * weight + selection_loss_per_example = tf.reduce_sum(selection_loss_per_token * input_mask_float, axis=1) / ( + tf.reduce_sum(input_mask_float, axis=1) + EPSILON_ZERO_DIVISION + ) + else: + selection_loss_per_example, logits = _single_column_cell_selection_loss( + logits, column_logits, labels, cell_index, col_index, cell_mask + ) + dist_per_token = tfp.distributions.Bernoulli(logits=logits) + + # Supervised cell selection + if self.config.disable_per_token_loss: + pass + elif is_supervised: + total_loss += tf.reduce_mean(selection_loss_per_example) + else: + # For the not supervised case, do not assign loss for cell selection + total_loss += tf.reduce_mean(selection_loss_per_example * (1.0 - aggregate_mask)) + + # Semi-supervised regression loss and supervised loss for aggregations + if self.config.num_aggregation_labels > 0: + if is_supervised: + # Note that `aggregate_mask` is None if the setting is supervised. + if aggregation_labels is not None: + assert shape_list(labels)[0] == shape_list(aggregation_labels)[0], ( + "Make sure the aggregation labels are a LongTensor of shape (batch_size,)" + ) + per_example_additional_loss = _calculate_aggregation_loss( + logits_aggregation, + aggregate_mask, + aggregation_labels, + self.config.use_answer_as_supervision, + self.config.num_aggregation_labels, + self.config.aggregation_loss_weight, + ) + else: + raise ValueError( + "You have to specify aggregation labels in order to calculate the aggregation loss" + ) + else: + aggregation_labels = tf.zeros(shape_list(labels)[0], dtype=tf.int32) + per_example_additional_loss = _calculate_aggregation_loss( + logits_aggregation, + aggregate_mask, + aggregation_labels, + self.config.use_answer_as_supervision, + self.config.num_aggregation_labels, + self.config.aggregation_loss_weight, + ) + + if self.config.use_answer_as_supervision: + if numeric_values is not None and numeric_values_scale is not None: + assert shape_list(numeric_values) == shape_list(numeric_values_scale) + # Add regression loss for numeric answers which require aggregation. + answer_loss, large_answer_loss_mask = _calculate_regression_loss( + float_answer, + aggregate_mask, + dist_per_token, + numeric_values, + numeric_values_scale, + table_mask_float, + logits_aggregation, + self.config, + ) + per_example_additional_loss += answer_loss + # Zero loss for examples with answer_loss > cutoff. 
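+                        # large_answer_loss_mask is 1.0 for examples that are kept and 0.0 for such outliers,
+                        # so this multiplication zeroes out their regression loss (cutoff taken from the config).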
+                        per_example_additional_loss *= large_answer_loss_mask
+                    else:
+                        raise ValueError(
+                            "You have to specify numeric values and numeric values scale in order to calculate the"
+                            " regression loss"
+                        )
+                total_loss += tf.reduce_mean(per_example_additional_loss)
+
+        else:
+            # if no label ids are provided, set them to zeros in order to properly compute logits
+            labels = tf.zeros_like(logits)
+            _, logits = _single_column_cell_selection_loss(
+                logits, column_logits, labels, cell_index, col_index, cell_mask
+            )
+        if not return_dict:
+            output = (logits, logits_aggregation) + outputs[2:]
+            return ((total_loss,) + output) if calculate_loss else output
+
+        return TFTableQuestionAnsweringOutput(
+            loss=total_loss if calculate_loss else None,
+            logits=logits,
+            logits_aggregation=logits_aggregation,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "tapas", None) is not None:
+            with tf.name_scope(self.tapas.name):
+                self.tapas.build(None)
+        if getattr(self, "compute_token_logits", None) is not None:
+            with tf.name_scope(self.compute_token_logits.name):
+                self.compute_token_logits.build(None)
+        if getattr(self, "compute_column_logits", None) is not None:
+            with tf.name_scope(self.compute_column_logits.name):
+                self.compute_column_logits.build(None)
+        if getattr(self, "aggregation_classifier", None) is not None:
+            with tf.name_scope(self.aggregation_classifier.name):
+                self.aggregation_classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table
+    entailment tasks, such as TabFact (Chen et al., 2020).
+    """,
+    TAPAS_START_DOCSTRING,
+)
+class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config: TapasConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.tapas = TFTapasMainLayer(config, name="tapas")
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+        self.classifier = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called + "classification_class_index" in the original implementation. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, TapasForSequenceClassification + >>> import tensorflow as tf + >>> import pandas as pd + + >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact") + >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact") + + >>> data = { + ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + ... "Age": ["56", "45", "59"], + ... "Number of movies": ["87", "53", "69"], + ... } + >>> table = pd.DataFrame.from_dict(data) + >>> queries = [ + ... "There is only one actor who is 45 years old", + ... "There are 3 actors which played in more than 60 movies", + ... ] + + >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf") + >>> labels = tf.convert_to_tensor([1, 0]) # 1 means entailed, 0 means refuted + + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + + outputs = self.tapas( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "tapas", None) is not None: + with tf.name_scope(self.tapas.name): + self.tapas.build(None) + if getattr(self, "dropout", None) is not None: + with tf.name_scope(self.dropout.name): + self.dropout.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +""" TAPAS utilities.""" + + +class AverageApproximationFunction(str, enum.Enum): + RATIO = "ratio" + FIRST_ORDER = "first_order" + SECOND_ORDER = "second_order" + + +# Beginning of everything related to segmented tensors + + +class IndexMap: + """Index grouping entries within a tensor.""" + + def __init__(self, indices, num_segments, batch_dims=0): + """ + Creates an index. + + Args: + indices: Tensor of indices, same shape as `values`. + num_segments: Scalar tensor, the number of segments. All elements + in a batched segmented tensor must have the same number of segments (although many segments can be empty). + batch_dims: Python integer, the number of batch dimensions. The first + `batch_dims` dimensions of a SegmentedTensor are treated as batch dimensions. Segments in different batch + elements are always distinct even if they have the same index. 
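+
+        Example (illustrative): `indices=[[0, 0, 1], [0, 1, 1]]` with `num_segments=2` and `batch_dims=1`
+        groups the first two tokens of batch element 0 into segment 0; equal indices in different batch
+        elements still denote distinct segments.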
+        """
+        self.indices = tf.convert_to_tensor(indices)
+        self.num_segments = tf.convert_to_tensor(num_segments)
+        self.batch_dims = batch_dims
+
+    def batch_shape(self):
+        return tf.shape(self.indices)[: self.batch_dims]
+
+
+class ProductIndexMap(IndexMap):
+    """The product of two indices."""
+
+    def __init__(self, outer_index, inner_index):
+        """
+        Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the
+        intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows
+        and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation
+        combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has `num_segments` equal to
+        `outer_index.num_segments` * `inner_index.num_segments`.
+
+        Args:
+            outer_index: IndexMap.
+            inner_index: IndexMap, must have the same shape as `outer_index`.
+        """
+        if outer_index.batch_dims != inner_index.batch_dims:
+            raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")
+
+        super(ProductIndexMap, self).__init__(
+            indices=(
+                inner_index.indices
+                + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
+            ),
+            num_segments=inner_index.num_segments * outer_index.num_segments,
+            batch_dims=inner_index.batch_dims,
+        )
+        self.outer_index = outer_index
+        self.inner_index = inner_index
+
+    def project_outer(self, index):
+        """Projects an index with the same index set onto the outer components."""
+        return IndexMap(
+            indices=tf.math.floordiv(index.indices, self.inner_index.num_segments),
+            num_segments=self.outer_index.num_segments,
+            batch_dims=index.batch_dims,
+        )
+
+    def project_inner(self, index):
+        """Projects an index with the same index set onto the inner components."""
+        return IndexMap(
+            indices=tf.math.floormod(index.indices, self.inner_index.num_segments),
+            num_segments=self.inner_index.num_segments,
+            batch_dims=index.batch_dims,
+        )
+
+
+def gather(values, index, name="segmented_gather"):
+    """
+    Gathers from `values` using the index map. For each element in the domain of the index map this operation looks up
+    a value for that index in `values`. Two elements from the same segment always get assigned the same value.
+
+    Args:
+        values: [B1, ..., Bn, num_segments, V1, ...] Tensor with segment values.
+        index: [B1, ..., Bn, I1, ..., Ik] IndexMap.
+        name: Name for the TensorFlow operation.
+
+    Returns:
+        [B1, ..., Bn, I1, ..., Ik, V1, ...] Tensor with the gathered values.
+    """
+    return tf.gather(values, index.indices, batch_dims=index.batch_dims, name=name)
+
+
+def flatten(index, name="segmented_flatten"):
+    """
+    Flattens a batched index map to a 1d index map. This operation relabels the segments to keep batch elements
+    distinct. The k-th batch element will have indices shifted by `num_segments` * (k - 1). The result is a tensor with
+    `num_segments` multiplied by the number of elements in the batch.
+
+    Args:
+        index: IndexMap to flatten.
+        name: Name for the TensorFlow operation.
+
+    Returns:
+        The flattened IndexMap.
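+
+    Example (illustrative): with a batch of 2 and `num_segments=3`, batched indices `[[0, 1], [0, 2]]`
+    flatten to `[0, 1, 3, 5]`, since batch element 1 is offset by 3.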
+    """
+    batch_size = tf.reduce_prod(index.batch_shape())
+    offset = tf.range(batch_size) * index.num_segments
+    offset = tf.reshape(offset, index.batch_shape())
+    for _ in range(index.batch_dims, index.indices.shape.rank):
+        offset = tf.expand_dims(offset, -1)
+
+    indices = tf.cast(offset, index.indices.dtype) + index.indices
+    return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
+
+
+def range_index_map(batch_shape, num_segments, name="range_index_map"):
+    """
+    Constructs an index map equal to range(num_segments).
+
+    Args:
+        batch_shape (`tf.Tensor`):
+            Batch shape
+        num_segments (`int`):
+            Number of segments
+        name (`str`, *optional*, defaults to 'range_index_map'):
+            Name for the operation. Currently not used
+
+    Returns:
+        (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
+    """
+    batch_shape = tf.convert_to_tensor(batch_shape)
+    batch_shape.shape.assert_has_rank(1)
+    num_segments = tf.convert_to_tensor(num_segments)
+    num_segments.shape.assert_has_rank(0)
+
+    indices = tf.range(num_segments)
+    shape = tf.concat([tf.ones_like(batch_shape, dtype=tf.int32), tf.expand_dims(num_segments, axis=0)], axis=0)
+    indices = tf.reshape(indices, shape)
+    multiples = tf.concat([batch_shape, [1]], axis=0)
+    indices = tf.tile(indices, multiples)
+    return IndexMap(indices=indices, num_segments=num_segments, batch_dims=batch_shape.shape.as_list()[0])
+
+
+def _segment_reduce(values, index, segment_reduce_fn, name):
+    """
+    Applies a segment reduction segment-wise.
+
+    Args:
+        values (`tf.Tensor`):
+            Tensor with segment values.
+        index (`IndexMap`):
+            IndexMap.
+        segment_reduce_fn (`Callable`):
+            The unsorted segment reduction op to apply, e.g. `tf.math.unsorted_segment_sum`.
+        name (`str`):
+            Name for the operation. Currently not used
+
+    Returns:
+        A pair (output_values, output_index) where `output_values` is the segment-reduced tensor and `output_index`
+        is an IndexMap of shape batch_shape with elements equal to range(num_segments).
+    """
+    # Flatten the batch dimensions, as segments ops do not support batching.
+    # However if `values` has extra dimensions to the right keep them
+    # unflattened. Segmented ops support vector-valued operations.
+    flat_index = flatten(index)
+    vector_shape = tf.shape(values)[index.indices.shape.rank :]
+    flattened_shape = tf.concat([[-1], vector_shape], axis=0)
+    flat_values = tf.reshape(values, flattened_shape)
+    segment_means = segment_reduce_fn(
+        data=flat_values, segment_ids=flat_index.indices, num_segments=flat_index.num_segments
+    )
+
+    # Unflatten the values.
+    new_shape = tf.concat([index.batch_shape(), [index.num_segments], vector_shape], axis=0)
+    output_values = tf.reshape(segment_means, new_shape)
+    output_index = range_index_map(index.batch_shape(), index.num_segments)
+    return output_values, output_index
+
+
+def reduce_mean(values, index, name="segmented_reduce_mean"):
+    """
+    Averages a tensor over its segments. Outputs 0 for empty segments. This operation computes the mean over segments,
+    with support for:
+
+    - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+    - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a mean of vectors
+    rather than scalars.
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be
+            averaged.
+        index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
+        name: Name for the TensorFlow ops.
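+
+    Example (illustrative): `values=[[1., 2., 3., 4.]]` with segment indices `[[0, 0, 1, 1]]` and
+    `num_segments=2` yields `output_values=[[1.5, 3.5]]`.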
+def _segment_reduce(values, index, segment_reduce_fn, name):
+    """
+    Applies a segment reduction segment-wise.
+
+    Args:
+        values (`tf.Tensor`):
+            Tensor with segment values.
+        index (`IndexMap`):
+            IndexMap.
+        segment_reduce_fn (`Callable`):
+            The unsorted segment reduction to apply, e.g. `tf.math.unsorted_segment_sum`,
+            `tf.math.unsorted_segment_mean`, `tf.math.unsorted_segment_max` or `tf.math.unsorted_segment_min`.
+        name (`str`):
+            Name for the operation. Currently not used
+
+    Returns:
+        A pair (output_values, output_index) where `output_values` contains the reduced values per segment and
+        `output_index` is an IndexMap of shape batch_shape with elements equal to range(num_segments).
+    """
+    # Flatten the batch dimensions, as segments ops do not support batching.
+    # However if `values` has extra dimensions to the right keep them
+    # unflattened. Segmented ops support vector-valued operations.
+    flat_index = flatten(index)
+    vector_shape = tf.shape(values)[index.indices.shape.rank :]
+    flattened_shape = tf.concat([[-1], vector_shape], axis=0)
+    flat_values = tf.reshape(values, flattened_shape)
+    segment_means = segment_reduce_fn(
+        data=flat_values, segment_ids=flat_index.indices, num_segments=flat_index.num_segments
+    )
+
+    # Unflatten the values.
+    new_shape = tf.concat([index.batch_shape(), [index.num_segments], vector_shape], axis=0)
+    output_values = tf.reshape(segment_means, new_shape)
+    output_index = range_index_map(index.batch_shape(), index.num_segments)
+    return output_values, output_index
+
+
+def reduce_mean(values, index, name="segmented_reduce_mean"):
+    """
+    Averages a tensor over its segments. Outputs 0 for empty segments. This operation computes the mean over segments,
+    with support for:
+
+    - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+    - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a mean of vectors
+      rather than scalars.
+
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be averaged.
+        index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
+        name: Name for the TensorFlow ops.
+
+    Returns:
+        A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn,
+        num_segments, V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    return _segment_reduce(values, index, tf.math.unsorted_segment_mean, name)
+
+
+def reduce_sum(values, index, name="segmented_reduce_sum"):
+    """
+    Sums a tensor over its segments. Outputs 0 for empty segments. This operation computes the sum over segments, with
+    support for:
+
+    - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+    - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a sum of vectors
+      rather than scalars.
+
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be summed.
+        index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
+        name: Name for the TensorFlow ops.
+
+    Returns:
+        A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn,
+        num_segments, V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    return _segment_reduce(values, index, tf.math.unsorted_segment_sum, name)
+
+
+def reduce_max(values, index, name="segmented_reduce_max"):
+    """
+    Computes the maximum over segments, with support for:
+
+    - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
+    - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be an element-wise
+      maximum of vectors rather than scalars.
+
+    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
+
+    Args:
+        values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values whose maximum is computed.
+        index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
+        name: Name for the TensorFlow ops.
+
+    Returns:
+        A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn,
+        num_segments, V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments].
+    """
+    return _segment_reduce(values, index, tf.math.unsorted_segment_max, name)
+
+
+def reduce_min(values, index, name="segmented_reduce_min"):
+    """Computes the minimum over segments."""
+    return _segment_reduce(values, index, tf.math.unsorted_segment_min, name)
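+
+# Illustrative sketch (not part of the original file): a segmented mean, assuming
+# TF 2.x eager execution. `_demo_reduce_mean` is a hypothetical helper added here
+# only for exposition.
+def _demo_reduce_mean():
+    values = tf.constant([1.0, 2.0, 4.0, 8.0])
+    # Tokens 0 and 1 belong to segment 0; tokens 2 and 3 to segment 1.
+    index = IndexMap(indices=tf.constant([0, 0, 1, 1]), num_segments=2, batch_dims=0)
+    means, out_index = reduce_mean(values, index)
+    # means == [1.5, 6.0]; out_index.indices == [0, 1]
+    return means, out_index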
+def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
+    """
+    Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood.
+    The model first predicts a column and then selects cells within that column (conditioned on the column). Cells
+    outside the selected column are never selected.
+
+    Args:
+        token_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Tensor containing the logits per token.
+        column_logits (`tf.Tensor` of shape `(batch_size, max_num_cols)`):
+            Tensor containing the logits per column.
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Labels per token.
+        cell_index (`ProductIndexMap`):
+            Index that groups tokens into cells.
+        col_index (`IndexMap`):
+            Index that groups tokens into columns.
+        cell_mask (`tf.Tensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
+            Mask for cells that exist in the table (i.e. that are not padding).
+
+    Returns:
+        selection_loss_per_example (`tf.Tensor` of shape `(batch_size,)`): Loss for each example.
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): New logits which are only allowed to select
+        cells in a single column. Logits outside of the most likely column according to *column_logits* will be set
+        to a very low value (such that the probabilities are 0).
+    """
+    # First find the column we should select. We use the column with maximum
+    # number of selected cells.
+    labels_per_column, _ = reduce_sum(tf.cast(labels, tf.float32), col_index)
+    column_label = tf.argmax(labels_per_column, axis=-1, output_type=tf.int32)
+    # Check if there are no selected cells in the column. In that case the model
+    # should predict the special column id 0, which means "select nothing".
+    no_cell_selected = tf.equal(tf.reduce_max(labels_per_column, axis=-1), 0)
+    column_label = tf.where(no_cell_selected, tf.zeros_like(column_label), column_label)
+
+    column_dist = tfp.distributions.Categorical(logits=column_logits)
+    column_loss_per_example = -column_dist.log_prob(column_label)
+
+    # Reduce the labels and logits to per-cell from per-token.
+    logits_per_cell, _ = reduce_mean(token_logits, cell_index)
+    labels_per_cell, labels_index = reduce_max(tf.cast(labels, tf.int32), cell_index)
+
+    # Mask for the selected column.
+    column_id_for_cells = cell_index.project_inner(labels_index).indices
+    column_mask = tf.cast(tf.equal(column_id_for_cells, tf.expand_dims(column_label, axis=1)), tf.float32)
+
+    # Compute the log-likelihood for cells, but only for the selected column.
+    cell_dist = tfp.distributions.Bernoulli(logits=logits_per_cell)
+    cell_log_prob = cell_dist.log_prob(labels_per_cell)
+    cell_loss = -tf.reduce_sum(cell_log_prob * column_mask * cell_mask, axis=1)
+    # We need to normalize the loss by the number of cells in the column.
+    cell_loss /= tf.reduce_sum(column_mask * cell_mask, axis=1) + EPSILON_ZERO_DIVISION
+
+    selection_loss_per_example = column_loss_per_example
+    selection_loss_per_example += tf.where(no_cell_selected, tf.zeros_like(selection_loss_per_example), cell_loss)
+
+    # Set the probs outside the selected column (selected by the *model*)
+    # to 0. This ensures backwards compatibility with models that select
+    # cells from multiple columns.
+    selected_column_id = tf.argmax(column_logits, axis=-1, output_type=tf.int32)
+    selected_column_mask = tf.cast(
+        tf.equal(column_id_for_cells, tf.expand_dims(selected_column_id, axis=-1)), tf.float32
+    )
+    # Never select cells with the special column id 0.
+    selected_column_mask = tf.where(
+        tf.equal(column_id_for_cells, 0), tf.zeros_like(selected_column_mask), selected_column_mask
+    )
+    logits_per_cell += CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * selected_column_mask)
+    logits = gather(logits_per_cell, cell_index)
+
+    return selection_loss_per_example, logits
+def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
+    """
+    Finds examples where the model should select cells with no aggregation.
+
+    Returns a mask that determines for which examples the model should select answers directly from the table,
+    without any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation
+    functions only apply to numbers. If the answer is a number but does not appear in the table then we must use some
+    aggregation case. The ambiguous case is when the answer is a number that also appears in the table. In this case
+    we use the aggregation function probabilities predicted by the model to decide whether to select or aggregate.
+    The threshold for this is a hyperparameter *cell_selection_preference*.
+
+    Args:
+        answer (`tf.Tensor` of shape `(batch_size, )`):
+            Answer for every example in the batch. Nan if there is no scalar answer.
+        pooled_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+            Output of the pooler (BertPooler) on top of the encoder layer.
+        cell_selection_preference (`float`):
+            Preference for cell selection in ambiguous cases.
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Labels per token.
+        aggregation_classifier (`keras.layers.Dense`):
+            Aggregation head.
+
+    Returns:
+        aggregate_mask (`tf.Tensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use
+        aggregation functions.
+    """
+    # tf.Tensor(batch_size,)
+    aggregate_mask_init = tf.cast(tf.logical_not(tf.math.is_nan(answer)), tf.float32)
+    logits_aggregation = aggregation_classifier(pooled_output)
+    dist_aggregation = tfp.distributions.Categorical(logits=logits_aggregation)
+    # Index 0 corresponds to "no aggregation".
+    aggregation_ops_total_mass = tf.reduce_sum(dist_aggregation.probs_parameter()[:, 1:], axis=1)
+    # Cell selection examples according to current model.
+    is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference
+    # Examples with non-empty cell selection supervision.
+    is_cell_supervision_available = tf.reduce_sum(labels, axis=1) > 0
+    aggregate_mask = tf.where(
+        tf.logical_and(is_pred_cell_selection, is_cell_supervision_available),
+        tf.zeros_like(aggregate_mask_init, dtype=tf.float32),
+        aggregate_mask_init,
+    )
+    aggregate_mask = tf.stop_gradient(aggregate_mask)
+    return aggregate_mask
+
+
+def _calculate_aggregation_loss_known(
+    logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
+):
+    """
+    Calculates aggregation loss when its type is known during training.
+
+    In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation"
+    should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting
+    where aggregation type is always known, standard cross entropy loss is accumulated for all examples.
+
+    Args:
+        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
+            Logits per aggregation operation.
+        aggregate_mask (`tf.Tensor` of shape `(batch_size, )`):
+            A mask set to 1 for examples that should use aggregation functions.
+        aggregation_labels (`tf.Tensor` of shape `(batch_size, )`):
+            Aggregation function id for every example in the batch.
+        use_answer_as_supervision (`bool`, *optional*):
+            Whether to use the answer as the only supervision for aggregation examples.
+        num_aggregation_labels (`int`, *optional*, defaults to 0):
+            The number of aggregation operators to predict.
+
+    Returns:
+        aggregation_loss_known (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss (when its type is known during
+        training) per example.
+    """
+    if use_answer_as_supervision:
+        # Prepare "no aggregation" targets for cell selection examples.
+        target_aggregation = tf.zeros_like(aggregate_mask, dtype=tf.int32)
+    else:
+        # Use aggregation supervision as the target.
+        target_aggregation = aggregation_labels
+
+    one_hot_labels = tf.one_hot(target_aggregation, depth=num_aggregation_labels, dtype=tf.float32)
+    log_probs = tf.nn.log_softmax(logits_aggregation, axis=-1)
+
+    # [batch_size]
+    per_example_aggregation_intermediate = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
+    if use_answer_as_supervision:
+        # Accumulate loss only for examples requiring cell selection
+        # (no aggregation).
+        return per_example_aggregation_intermediate * (1 - aggregate_mask)
+    else:
+        return per_example_aggregation_intermediate
+
+
+def _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask):
+    """
+    Calculates aggregation loss in the case of answer supervision.
+
+    Args:
+        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
+            Logits per aggregation operation.
+        aggregate_mask (`tf.Tensor` of shape `(batch_size, )`):
+            A mask set to 1 for examples that should use aggregation functions.
+
+    Returns:
+        aggregation_loss_unknown (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss (in case of answer
+        supervision) per example.
+    """
+    dist_aggregation = tfp.distributions.Categorical(logits=logits_aggregation)
+    # Index 0 corresponds to "no aggregation".
+    aggregation_ops_total_mass = tf.reduce_sum(dist_aggregation.probs_parameter()[:, 1:], axis=1)
+    # Predict some aggregation in case of an answer that needs aggregation.
+    # This increases the probability of all aggregation functions, in a way
+    # similar to MML, but without considering whether the function gives the
+    # correct answer.
+    return -tf.math.log(aggregation_ops_total_mass) * aggregate_mask
+
+
+def _calculate_aggregation_loss(
+    logits_aggregation,
+    aggregate_mask,
+    aggregation_labels,
+    use_answer_as_supervision,
+    num_aggregation_labels,
+    aggregation_loss_weight,
+):
+    """
+    Calculates the aggregation loss per example.
+
+    Args:
+        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
+            Logits per aggregation operation.
+        aggregate_mask (`tf.Tensor` of shape `(batch_size, )`):
+            A mask set to 1 for examples that should use aggregation functions.
+        aggregation_labels (`tf.Tensor` of shape `(batch_size, )`):
+            Aggregation function id for every example in the batch.
+        use_answer_as_supervision (`bool`, *optional*):
+            Whether to use the answer as the only supervision for aggregation examples.
+        num_aggregation_labels (`int`, *optional*, defaults to 0):
+            The number of aggregation operators to predict.
+        aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
+            Importance weight for the aggregation loss.
+
+    Returns:
+        aggregation_loss (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss per example.
+    """
+    per_example_aggregation_loss = _calculate_aggregation_loss_known(
+        logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
+    )
+
+    if use_answer_as_supervision:
+        # Add aggregation loss for numeric answers that need aggregation.
+        per_example_aggregation_loss += _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask)
+    return aggregation_loss_weight * per_example_aggregation_loss
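+
+# Illustrative sketch (not part of the original file): how the aggregate mask gates
+# the two aggregation-loss terms, assuming TF 2.x eager execution and 4 aggregation
+# labels (NONE, SUM, AVERAGE, COUNT). `_demo_aggregation_loss` is a hypothetical
+# helper added here only for exposition.
+def _demo_aggregation_loss():
+    logits_aggregation = tf.constant([[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0]])
+    # Example 0 should select cells (mask 0); example 1 should aggregate (mask 1).
+    aggregate_mask = tf.constant([0.0, 1.0])
+    aggregation_labels = tf.constant([0, 1], dtype=tf.int32)
+    # With use_answer_as_supervision=True, the "known" term only penalizes example 0
+    # (it must predict NONE) and the "unknown" term only pushes example 1 towards
+    # putting probability mass on some non-NONE operator.
+    return _calculate_aggregation_loss(
+        logits_aggregation,
+        aggregate_mask,
+        aggregation_labels,
+        use_answer_as_supervision=True,
+        num_aggregation_labels=4,
+        aggregation_loss_weight=1.0,
+    )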
+def _calculate_expected_result(
+    dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
+):
+    """
+    Calculates the expected result given cell and aggregation probabilities.
+
+    Args:
+        dist_per_cell (`tfp.distributions.Bernoulli`):
+            Cell selection distribution for each cell.
+        numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`):
+            Numeric values of every token. Nan for tokens which are not numeric values.
+        numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`):
+            Scale of the numeric values of every token.
+        input_mask_float (`tf.Tensor` of shape `(batch_size, seq_length)`):
+            Mask for the table, without question tokens and table headers.
+        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
+            Logits per aggregation operation.
+        config ([`TapasConfig`]):
+            Model configuration class with all the hyperparameters of the model
+
+    Returns:
+        expected_result (`tf.Tensor` of shape `(batch_size,)`): The expected result per example.
+    """
+    if config.use_gumbel_for_cells:
+        gumbel_dist = tfp.distributions.RelaxedBernoulli(
+            # The token logits were already divided by the temperature and used for
+            # computing cell selection errors so we need to multiply it again here
+            config.temperature,
+            logits=dist_per_cell.logits_parameter() * config.temperature,
+        )
+        scaled_probability_per_cell = gumbel_dist.sample()
+    else:
+        scaled_probability_per_cell = dist_per_cell.probs_parameter()
+
+    # [batch_size, seq_length]
+    scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float
+    count_result = tf.reduce_sum(scaled_probability_per_cell, axis=1)
+    numeric_values_masked = tf.where(
+        tf.math.is_nan(numeric_values), tf.zeros_like(numeric_values), numeric_values
+    )  # Mask non-numeric table values to zero.
+    sum_result = tf.reduce_sum(scaled_probability_per_cell * numeric_values_masked, axis=1)
+    avg_approximation = config.average_approximation_function
+    if avg_approximation == AverageApproximationFunction.RATIO:
+        average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION)
+    elif avg_approximation == AverageApproximationFunction.FIRST_ORDER:
+        # The sum of all probabilities except those that correspond to other cells
+        ex = tf.reduce_sum(scaled_probability_per_cell, axis=1, keepdims=True) - scaled_probability_per_cell + 1
+        average_result = tf.reduce_sum(numeric_values_masked * scaled_probability_per_cell / ex, axis=1)
+    elif avg_approximation == AverageApproximationFunction.SECOND_ORDER:
+        # The sum of all probabilities except those that correspond to other cells
+        ex = tf.reduce_sum(scaled_probability_per_cell, axis=1, keepdims=True) - scaled_probability_per_cell + 1
+        pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell)
+        var = tf.reduce_sum(pointwise_var, axis=1, keepdims=True) - pointwise_var
+        multiplier = (var / tf.math.square(ex) + 1) / ex
+        average_result = tf.reduce_sum(numeric_values_masked * scaled_probability_per_cell * multiplier, axis=1)
+    else:
+        raise ValueError(f"Invalid average_approximation_function: {config.average_approximation_function}")
+
+    if config.use_gumbel_for_aggregation:
+        gumbel_dist = tfp.distributions.RelaxedOneHotCategorical(
+            config.aggregation_temperature, logits=logits_aggregation[:, 1:]
+        )
+        # [batch_size, num_aggregation_labels - 1]
+        aggregation_op_only_probs = gumbel_dist.sample()
+    else:
+        # [batch_size, num_aggregation_labels - 1]
+        aggregation_op_only_probs = stable_softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1)
+    all_results = tf.concat(
+        [
+            tf.expand_dims(sum_result, axis=1),
+            tf.expand_dims(average_result, axis=1),
+            tf.expand_dims(count_result, axis=1),
+        ],
+        axis=1,
+    )
+    expected_result = tf.reduce_sum(all_results * aggregation_op_only_probs, axis=1)
+    return expected_result
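+
+# Illustrative sketch (not part of the original file): the soft SUM/COUNT idea behind
+# `_calculate_expected_result`, written out for a single example with hypothetical
+# cell probabilities. Assumes TF 2.x eager execution.
+def _demo_soft_sum_and_count():
+    # Two table tokens with values 10 and 20, selected with probability 0.9 and 0.5.
+    prob_per_cell = tf.constant([[0.9, 0.5]])
+    numeric_values = tf.constant([[10.0, 20.0]])
+    count_result = tf.reduce_sum(prob_per_cell, axis=1)  # [1.4]
+    sum_result = tf.reduce_sum(prob_per_cell * numeric_values, axis=1)  # [19.0]
+    # The final expected result then mixes SUM, AVERAGE and COUNT with the
+    # (non-NONE) aggregation probabilities.
+    return sum_result, count_result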
+
+
+def _calculate_regression_loss(
+    answer,
+    aggregate_mask,
+    dist_per_cell,
+    numeric_values,
+    numeric_values_scale,
+    input_mask_float,
+    logits_aggregation,
+    config,
+):
+    """
+    Calculates the regression loss per example.
+
+    Args:
+        answer (`tf.Tensor` of shape `(batch_size,)`):
+            Answer for every example in the batch. Nan if there is no scalar answer.
+        aggregate_mask (`tf.Tensor` of shape `(batch_size,)`):
+            A mask set to 1 for examples that should use aggregation functions.
+        dist_per_cell (`tfp.distributions.Bernoulli`):
+            Cell selection distribution for each cell.
+        numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`):
+            Numeric values of every token. Nan for tokens which are not numeric values.
+        numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`):
+            Scale of the numeric values of every token.
+        input_mask_float (`tf.Tensor` of shape `(batch_size, seq_length)`):
+            Mask for the table, without question tokens and table headers.
+        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
+            Logits per aggregation operation.
+        config ([`TapasConfig`]):
+            Model configuration class with all the parameters of the model
+
+    Returns:
+        per_example_answer_loss_scaled (`tf.Tensor` of shape `(batch_size,)`): Scaled answer loss for each example in
+        the batch.
+        large_answer_loss_mask (`tf.Tensor` of shape `(batch_size,)`): A mask which is 1 for examples for which their
+        answer loss is larger than the answer_loss_cutoff.
+    """
+    # float32 (batch_size,)
+    expected_result = _calculate_expected_result(
+        dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
+    )
+
+    # [batch_size]
+    answer_masked = tf.where(tf.math.is_nan(answer), tf.zeros_like(answer), answer)
+
+    if config.use_normalized_answer_loss:
+        normalizer = tf.stop_gradient(
+            tf.math.maximum(tf.math.abs(expected_result), tf.math.abs(answer_masked)) + EPSILON_ZERO_DIVISION
+        )
+        normalized_answer_masked = answer_masked / normalizer
+        normalized_expected_result = expected_result / normalizer
+        per_example_answer_loss = tf.compat.v1.losses.huber_loss(
+            normalized_answer_masked * aggregate_mask,
+            normalized_expected_result * aggregate_mask,
+            delta=tf.cast(1.0, tf.float32),
+            reduction=tf.losses.Reduction.NONE,
+        )
+    else:
+        per_example_answer_loss = tf.compat.v1.losses.huber_loss(
+            answer_masked * aggregate_mask,
+            expected_result * aggregate_mask,
+            delta=tf.cast(config.huber_loss_delta, tf.float32),
+            reduction=tf.losses.Reduction.NONE,
+        )
+    if config.answer_loss_cutoff is None:
+        large_answer_loss_mask = tf.ones_like(per_example_answer_loss, dtype=tf.float32)
+    else:
+        large_answer_loss_mask = tf.where(
+            per_example_answer_loss > config.answer_loss_cutoff,
+            tf.zeros_like(per_example_answer_loss, dtype=tf.float32),
+            tf.ones_like(per_example_answer_loss, dtype=tf.float32),
+        )
+    per_example_answer_loss_scaled = config.answer_loss_importance * (per_example_answer_loss * aggregate_mask)
+    return per_example_answer_loss_scaled, large_answer_loss_mask
+
+
+__all__ = [
+    "TFTapasForMaskedLM",
+    "TFTapasForQuestionAnswering",
+    "TFTapasForSequenceClassification",
+    "TFTapasModel",
+    "TFTapasPreTrainedModel",
+]
diff --git a/docs/transformers/src/transformers/models/tapas/tokenization_tapas.py b/docs/transformers/src/transformers/models/tapas/tokenization_tapas.py
new file mode 100644
index 0000000000000000000000000000000000000000..b290b4990d1e815c81b7291eaa912dfc2c3d933c
--- /dev/null
+++ b/docs/transformers/src/transformers/models/tapas/tokenization_tapas.py
@@ -0,0 +1,2792 @@
+# coding=utf-8
+# Copyright 2020 Google Research and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for TAPAS model."""
+
+import collections
+import datetime
+import enum
+import itertools
+import math
+import os
+import re
+import unicodedata
+from dataclasses import dataclass
+from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...tokenization_utils_base import (
+    ENCODE_KWARGS_DOCSTRING,
+    VERY_LARGE_INTEGER,
+    BatchEncoding,
+    EncodedInput,
+    PreTokenizedInput,
+    TextInput,
+)
+from ...utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available, logging
+
+
+if is_pandas_available():
+    import pandas as pd
+
+logger = logging.get_logger(__name__)
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+class TapasTruncationStrategy(ExplicitEnum):
+    """
+    Possible values for the `truncation` argument in [`~TapasTokenizer.__call__`]. Useful for tab-completion in an
+    IDE.
+    """
+
+    DROP_ROWS_TO_FIT = "drop_rows_to_fit"
+    DO_NOT_TRUNCATE = "do_not_truncate"
+
+
+TableValue = collections.namedtuple("TableValue", ["token", "column_id", "row_id"])
+
+
+@dataclass(frozen=True)
+class TokenCoordinates:
+    column_index: int
+    row_index: int
+    token_index: int
+
+
+@dataclass
+class TokenizedTable:
+    rows: List[List[List[str]]]
+    selected_tokens: List[TokenCoordinates]
+
+
+@dataclass(frozen=True)
+class SerializedExample:
+    tokens: List[str]
+    column_ids: List[int]
+    row_ids: List[int]
+    segment_ids: List[int]
+
+
+def _is_inner_wordpiece(token: str):
+    return token.startswith("##")
+
+
+def load_vocab(vocab_file):
+    """Loads a vocabulary file into a dictionary."""
+    vocab = collections.OrderedDict()
+    with open(vocab_file, "r", encoding="utf-8") as reader:
+        tokens = reader.readlines()
+    for index, token in enumerate(tokens):
+        token = token.rstrip("\n")
+        vocab[token] = index
+    return vocab
+
+
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+            add_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to encode the sequences with the special tokens relative to their model.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Activates and controls padding. Accepts the following values:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`TapasTruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length` + or to the maximum acceptable input length for the model if that argument is not provided. This will + truncate row by row, removing rows from the table. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + + +class TapasTokenizer(PreTrainedTokenizer): + r""" + Construct a TAPAS tokenizer. Based on WordPiece. Flattens a table and one or more related sentences to be used by + TAPAS models. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. [`TapasTokenizer`] creates several token type ids to + encode tabular structure. To be more precise, it adds 7 token type ids, in the following order: `segment_ids`, + `column_ids`, `row_ids`, `prev_labels`, `column_ranks`, `inv_column_ranks` and `numeric_relations`: + + - segment_ids: indicate whether a token belongs to the question (0) or the table (1). 0 for special tokens and + padding. + - column_ids: indicate to which column of the table a token belongs (starting from 1). Is 0 for all question + tokens, special tokens and padding. + - row_ids: indicate to which row of the table a token belongs (starting from 1). Is 0 for all question tokens, + special tokens and padding. Tokens of column headers are also 0. + - prev_labels: indicate whether a token was (part of) an answer to the previous question (1) or not (0). Useful in + a conversational setup (such as SQA). + - column_ranks: indicate the rank of a table token relative to a column, if applicable. For example, if you have a + column "number of movies" with values 87, 53 and 69, then the column ranks of these tokens are 3, 1 and 2 + respectively. 0 for all question tokens, special tokens and padding. 
+ - inv_column_ranks: indicate the inverse rank of a table token relative to a column, if applicable. For example, if + you have a column "number of movies" with values 87, 53 and 69, then the inverse column ranks of these tokens are + 1, 3 and 2 respectively. 0 for all question tokens, special tokens and padding. + - numeric_relations: indicate numeric relations between the question and the tokens of the table. 0 for all + question tokens, special tokens and padding. + + [`TapasTokenizer`] runs end-to-end tokenization on a table and associated sentences: punctuation splitting and + wordpiece. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + empty_token (`str`, *optional*, defaults to `"[EMPTY]"`): + The token used for empty cell values in a table. Empty cell values include "", "n/a", "nan" and "?". + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + cell_trim_length (`int`, *optional*, defaults to -1): + If > 0: Trim cells so that the length is <= this value. Also disables further cell trimming, should thus be + used with `truncation` set to `True`. + max_column_id (`int`, *optional*): + Max column id to extract. + max_row_id (`int`, *optional*): + Max row id to extract. + strip_column_names (`bool`, *optional*, defaults to `False`): + Whether to add empty strings instead of column names. + update_answer_coordinates (`bool`, *optional*, defaults to `False`): + Whether to recompute the answer coordinates from the answer text. + min_question_length (`int`, *optional*): + Minimum length of each question in terms of tokens (will be skipped otherwise). 
+ max_question_length (`int`, *optional*): + Maximum length of each question in terms of tokens (will be skipped otherwise). + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + empty_token="[EMPTY]", + tokenize_chinese_chars=True, + strip_accents=None, + cell_trim_length: int = -1, + max_column_id: Optional[int] = None, + max_row_id: Optional[int] = None, + strip_column_names: bool = False, + update_answer_coordinates: bool = False, + min_question_length=None, + max_question_length=None, + model_max_length: int = 512, + additional_special_tokens: Optional[List[str]] = None, + clean_up_tokenization_spaces=True, + **kwargs, + ): + if not is_pandas_available(): + raise ImportError("Pandas is required for the TAPAS tokenizer.") + + if additional_special_tokens is not None: + if empty_token not in additional_special_tokens: + additional_special_tokens.append(empty_token) + else: + additional_special_tokens = [empty_token] + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + # Additional properties + self.cell_trim_length = cell_trim_length + self.max_column_id = ( + max_column_id + if max_column_id is not None + else model_max_length + if model_max_length is not None + else VERY_LARGE_INTEGER + ) + self.max_row_id = ( + max_row_id + if max_row_id is not None + else model_max_length + if model_max_length is not None + else VERY_LARGE_INTEGER + ) + self.strip_column_names = strip_column_names + self.update_answer_coordinates = update_answer_coordinates + self.min_question_length = min_question_length + self.max_question_length = max_question_length + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + empty_token=empty_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + cell_trim_length=cell_trim_length, + max_column_id=max_column_id, + max_row_id=max_row_id, + strip_column_names=strip_column_names, + update_answer_coordinates=update_answer_coordinates, + min_question_length=min_question_length, + max_question_length=max_question_length, + model_max_length=model_max_length, + additional_special_tokens=additional_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def 
vocab_size(self):
+        return len(self.vocab)
+
+    def get_vocab(self):
+        return dict(self.vocab, **self.added_tokens_encoder)
+
+    def _tokenize(self, text):
+        if format_text(text) == EMPTY_TEXT:
+            return [self.additional_special_tokens[0]]
+        split_tokens = []
+        if self.do_basic_tokenize:
+            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+                # If the token is part of the never_split set
+                if token in self.basic_tokenizer.never_split:
+                    split_tokens.append(token)
+                else:
+                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) into a single string."""
+        out_string = " ".join(tokens).replace(" ##", "").strip()
+        return out_string
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        index = 0
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+        else:
+            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                        " Please check that the vocabulary is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(token + "\n")
+                index += 1
+        return (vocab_file,)
+
+    def create_attention_mask_from_sequences(self, query_ids: List[int], table_values: List[TableValue]) -> List[int]:
+        """
+        Creates the attention mask according to the query token IDs and a list of table values.
+
+        Args:
+            query_ids (`List[int]`): list of token IDs corresponding to the query.
+            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
+                token value, the column ID and the row ID of said token.
+
+        Returns:
+            `List[int]`: List of ints containing the attention mask values.
+        """
+        return [1] * (1 + len(query_ids) + 1 + len(table_values))
+
+    def create_segment_token_type_ids_from_sequences(
+        self, query_ids: List[int], table_values: List[TableValue]
+    ) -> List[int]:
+        """
+        Creates the segment token type IDs according to the query token IDs and a list of table values.
+
+        Args:
+            query_ids (`List[int]`): list of token IDs corresponding to the query.
+            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
+                token value, the column ID and the row ID of said token.
+
+        Returns:
+            `List[int]`: List of ints containing the segment token type IDs values.
+        """
+        table_ids = list(zip(*table_values))[0] if table_values else []
+        return [0] * (1 + len(query_ids) + 1) + [1] * len(table_ids)
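+
+    # Illustrative note (not part of the original file): for a hypothetical query of
+    # 3 token IDs and a flattened table of 4 token values, the helpers above produce,
+    # accounting for [CLS] and [SEP]:
+    #
+    #   create_attention_mask_from_sequences(q, t)         -> [1, 1, 1, 1, 1, 1, 1, 1, 1]
+    #   create_segment_token_type_ids_from_sequences(q, t) -> [0, 0, 0, 0, 0, 1, 1, 1, 1]
+    #
+    # i.e. 1 + len(query_ids) + 1 positions for "[CLS] query [SEP]" (segment 0),
+    # followed by one position per table token (segment 1).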
+    def create_column_token_type_ids_from_sequences(
+        self, query_ids: List[int], table_values: List[TableValue]
+    ) -> List[int]:
+        """
+        Creates the column token type IDs according to the query token IDs and a list of table values.
+
+        Args:
+            query_ids (`List[int]`): list of token IDs corresponding to the query.
+            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
+                token value, the column ID and the row ID of said token.
+
+        Returns:
+            `List[int]`: List of ints containing the column token type IDs values.
+        """
+        table_column_ids = list(zip(*table_values))[1] if table_values else []
+        return [0] * (1 + len(query_ids) + 1) + list(table_column_ids)
+
+    def create_row_token_type_ids_from_sequences(
+        self, query_ids: List[int], table_values: List[TableValue]
+    ) -> List[int]:
+        """
+        Creates the row token type IDs according to the query token IDs and a list of table values.
+
+        Args:
+            query_ids (`List[int]`): list of token IDs corresponding to the query.
+            table_values (`List[TableValue]`): list of table values, which are named tuples containing the
+                token value, the column ID and the row ID of said token.
+
+        Returns:
+            `List[int]`: List of ints containing the row token type IDs values.
+        """
+        table_row_ids = list(zip(*table_values))[2] if table_values else []
+        return [0] * (1 + len(query_ids) + 1) + list(table_row_ids)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a question and flattened table for question answering or sequence classification tasks
+        by concatenating and adding special tokens.
+
+        Args:
+            token_ids_0 (`List[int]`): The ids of the question.
+            token_ids_1 (`List[int]`, *optional*): The ids of the flattened table.
+
+        Returns:
+            `List[int]`: The model input with special tokens.
+        """
+        if token_ids_1 is None:
+            raise ValueError("With TAPAS, you must provide both question IDs and table IDs.")
+
+        return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1
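+
+    # Illustrative note (not part of the original file): with hypothetical IDs,
+    # build_inputs_with_special_tokens([5, 6], [7, 8, 9]) returns
+    # [cls_token_id, 5, 6, sep_token_id, 7, 8, 9], i.e. "[CLS] question [SEP] table".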
+ """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + return [1] + ([0] * len(token_ids_0)) + [1] + + @add_end_docstrings(TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + table: "pd.DataFrame", + queries: Optional[ + Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], + ] + ] = None, + answer_coordinates: Optional[Union[List[Tuple], List[List[Tuple]]]] = None, + answer_text: Optional[Union[List[TextInput], List[List[TextInput]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) related to a table. + + Args: + table (`pd.DataFrame`): + Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas + dataframe to convert it to string. + queries (`str` or `List[str]`): + Question or batch of questions related to a table to be encoded. Note that in case of a batch, all + questions must refer to the **same** table. + answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*): + Answer coordinates of each table-question pair in the batch. In case only a single table-question pair + is provided, then the answer_coordinates must be a single list of one or more tuples. Each tuple must + be a (row_index, column_index) pair. The first data row (not the column header row) has index 0. The + first column has index 0. In case a batch of table-question pairs is provided, then the + answer_coordinates must be a list of lists of tuples (each list corresponding to a single + table-question pair). + answer_text (`List[str]` or `List[List[str]]`, *optional*): + Answer text of each table-question pair in the batch. In case only a single table-question pair is + provided, then the answer_text must be a single list of one or more strings. Each string must be the + answer text of a corresponding answer coordinate. In case a batch of table-question pairs is provided, + then the answer_coordinates must be a list of lists of strings (each list corresponding to a single + table-question pair). + """ + assert isinstance(table, pd.DataFrame), "Table must be of type pd.DataFrame" + + # Input type checking for clearer error + valid_query = False + + # Check that query has a valid type + if queries is None or isinstance(queries, str): + valid_query = True + elif isinstance(queries, (list, tuple)): + if len(queries) == 0 or isinstance(queries[0], str): + valid_query = True + + if not valid_query: + raise ValueError( + "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized" + " example). 
" + ) + is_batched = isinstance(queries, (list, tuple)) + + if is_batched: + return self.batch_encode_plus( + table=table, + queries=queries, + answer_coordinates=answer_coordinates, + answer_text=answer_text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + table=table, + query=queries, + answer_coordinates=answer_coordinates, + answer_text=answer_text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + table: "pd.DataFrame", + queries: Optional[ + Union[ + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], + ] + ] = None, + answer_coordinates: Optional[List[List[Tuple]]] = None, + answer_text: Optional[List[List[TextInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Prepare a table and a list of strings for the model. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + table (`pd.DataFrame`): + Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas + dataframe to convert it to string. + queries (`List[str]`): + Batch of questions related to a table to be encoded. Note that all questions must refer to the **same** + table. + answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*): + Answer coordinates of each table-question pair in the batch. Each tuple must be a (row_index, + column_index) pair. The first data row (not the column header row) has index 0. The first column has + index 0. The answer_coordinates must be a list of lists of tuples (each list corresponding to a single + table-question pair). + answer_text (`List[str]` or `List[List[str]]`, *optional*): + Answer text of each table-question pair in the batch. 
In case a batch of table-question pairs is
+                provided, then the answer_text must be a list of lists of strings (each list corresponding to a
+                single table-question pair). Each string must be the answer text of a corresponding answer coordinate.
+        """
+        if return_token_type_ids is not None and not add_special_tokens:
+            raise ValueError(
+                "Asking to return token_type_ids while setting add_special_tokens to False "
+                "results in an undefined behavior. Please set add_special_tokens to True or "
+                "set return_token_type_ids to None."
+            )
+
+        if (answer_coordinates and not answer_text) or (not answer_coordinates and answer_text):
+            raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided")
+        elif answer_coordinates is None and answer_text is None:
+            answer_coordinates = answer_text = [None] * len(queries)
+
+        if "is_split_into_words" in kwargs:
+            raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.")
+
+        if return_offsets_mapping:
+            raise NotImplementedError(
+                "return_offset_mapping is not available when using Python tokenizers. "
+                "To use this feature, change your tokenizer to one deriving from "
+                "transformers.PreTrainedTokenizerFast."
+            )
+
+        return self._batch_encode_plus(
+            table=table,
+            queries=queries,
+            answer_coordinates=answer_coordinates,
+            answer_text=answer_text,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+            **kwargs,
+        )
+
+    def _get_question_tokens(self, query):
+        """Tokenizes the query, taking into account the max and min question length."""
+
+        query_tokens = self.tokenize(query)
+        if self.max_question_length is not None and len(query_tokens) > self.max_question_length:
+            logger.warning("Skipping query as its tokens are longer than the max question length")
+            return "", []
+        if self.min_question_length is not None and len(query_tokens) < self.min_question_length:
+            logger.warning("Skipping query as its tokens are shorter than the min question length")
+            return "", []
+
+        return query, query_tokens
+
+    def _batch_encode_plus(
+        self,
+        table,
+        queries: Union[
+            List[TextInput],
+            List[PreTokenizedInput],
+            List[EncodedInput],
+        ],
+        answer_coordinates: Optional[List[List[Tuple]]] = None,
+        answer_text: Optional[List[List[TextInput]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TapasTruncationStrategy] = False,
+        max_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = True,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        table_tokens = self._tokenize_table(table)
+
+        queries_tokens = []
+        for idx, query in enumerate(queries):
+            query, query_tokens = 
self._get_question_tokens(query) + queries[idx] = query + queries_tokens.append(query_tokens) + + batch_outputs = self._batch_prepare_for_model( + table, + queries, + tokenized_table=table_tokens, + queries_tokens=queries_tokens, + answer_coordinates=answer_coordinates, + padding=padding, + truncation=truncation, + answer_text=answer_text, + add_special_tokens=add_special_tokens, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + def _batch_prepare_for_model( + self, + raw_table: "pd.DataFrame", + raw_queries: Union[ + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], + ], + tokenized_table: Optional[TokenizedTable] = None, + queries_tokens: Optional[List[List[str]]] = None, + answer_coordinates: Optional[List[List[Tuple]]] = None, + answer_text: Optional[List[List[TextInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = True, + return_attention_mask: Optional[bool] = True, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + batch_outputs = {} + + for index, example in enumerate(zip(raw_queries, queries_tokens, answer_coordinates, answer_text)): + raw_query, query_tokens, answer_coords, answer_txt = example + outputs = self.prepare_for_model( + raw_table, + raw_query, + tokenized_table=tokenized_table, + query_tokens=query_tokens, + answer_coordinates=answer_coords, + answer_text=answer_txt, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterwards + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=None, # we pad in batch afterwards + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterwards + return_token_type_ids=return_token_type_ids, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + prev_answer_coordinates=answer_coordinates[index - 1] if index != 0 else None, + prev_answer_text=answer_text[index - 1] if index != 0 else None, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING) + def encode( + self, + table: "pd.DataFrame", + query: Optional[ + Union[ + TextInput, + 
PreTokenizedInput,
+                EncodedInput,
+            ]
+        ] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TapasTruncationStrategy] = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> List[int]:
+        """
+        Prepare a table and a string for the model. This method does not return token type IDs, attention masks, etc.
+        which are necessary for the model to work correctly. Use this method if you want to build your processing on
+        your own, otherwise refer to `__call__`.
+
+        Args:
+            table (`pd.DataFrame`):
+                Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
+                dataframe to convert it to string.
+            query (`str` or `List[str]`):
+                Question related to a table to be encoded.
+        """
+        encoded_inputs = self.encode_plus(
+            table,
+            query=query,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            return_tensors=return_tensors,
+            **kwargs,
+        )
+
+        return encoded_inputs["input_ids"]
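+
+    # Illustrative usage (not part of the original file), assuming a tokenizer
+    # instance and pandas are available:
+    #
+    #   data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio"], "Age": ["56", "45"]}
+    #   table = pd.DataFrame.from_dict(data)
+    #   input_ids = tokenizer.encode(table=table, query="how old is Brad Pitt?")
+    #
+    # `input_ids` is the flat list of token IDs for "[CLS] query [SEP] table".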
+ ) + + if (answer_coordinates and not answer_text) or (not answer_coordinates and answer_text): + raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided") + + if "is_split_into_words" in kwargs: + raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.") + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + return self._encode_plus( + table=table, + query=query, + answer_coordinates=answer_coordinates, + answer_text=answer_text, + add_special_tokens=add_special_tokens, + truncation=truncation, + padding=padding, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + table: "pd.DataFrame", + query: Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + ], + answer_coordinates: Optional[List[Tuple]] = None, + answer_text: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = True, + return_attention_mask: Optional[bool] = True, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ): + if query is None: + query = "" + logger.warning( + "TAPAS is a question answering model but you have not passed a query. Please be aware that the " + "model will probably not behave correctly." 
+            )
+
+        table_tokens = self._tokenize_table(table)
+        query, query_tokens = self._get_question_tokens(query)
+
+        return self.prepare_for_model(
+            table,
+            query,
+            tokenized_table=table_tokens,
+            query_tokens=query_tokens,
+            answer_coordinates=answer_coordinates,
+            answer_text=answer_text,
+            add_special_tokens=add_special_tokens,
+            truncation=truncation,
+            padding=padding,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            prepend_batch_axis=True,
+            return_attention_mask=return_attention_mask,
+            return_token_type_ids=return_token_type_ids,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_length=return_length,
+            verbose=verbose,
+        )
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def prepare_for_model(
+        self,
+        raw_table: "pd.DataFrame",
+        raw_query: Union[
+            TextInput,
+            PreTokenizedInput,
+            EncodedInput,
+        ],
+        tokenized_table: Optional[TokenizedTable] = None,
+        query_tokens: Optional[List[str]] = None,
+        answer_coordinates: Optional[List[Tuple]] = None,
+        answer_text: Optional[List[TextInput]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TapasTruncationStrategy] = False,
+        max_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = True,
+        return_attention_mask: Optional[bool] = True,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        prepend_batch_axis: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids so that it can be used by the model. It adds special tokens, and truncates
+        sequences if overflowing while taking into account the special tokens.
+
+        Args:
+            raw_table (`pd.DataFrame`):
+                The original table before any transformation (like tokenization) was applied to it.
+            raw_query (`TextInput` or `PreTokenizedInput` or `EncodedInput`):
+                The original query before any transformation (like tokenization) was applied to it.
+            tokenized_table (`TokenizedTable`):
+                The table after tokenization.
+            query_tokens (`List[str]`):
+                The query after tokenization.
+            answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*):
+                Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single
+                list of one or more tuples. Each tuple must be a (row_index, column_index) pair. The first data row
+                (not the column header row) has index 0. The first column has index 0.
+            answer_text (`List[str]` or `List[List[str]]`, *optional*):
+                Answer text of each table-question pair in the batch. The answer_text must be a single list of one or
+                more strings. Each string must be the answer text of a corresponding answer coordinate.
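+
+        Example (a minimal sketch, assuming `tokenizer` is an already-initialized `TapasTokenizer`; the table and
+        query are hypothetical and the output is illustrative):
+
+        ```python
+        >>> import pandas as pd
+
+        >>> table = pd.DataFrame({"Actors": ["Brad Pitt", "Leonardo Di Caprio"], "Age": ["56", "45"]}).astype(str)
+        >>> table_tokens = tokenizer._tokenize_table(table)
+        >>> query, query_tokens = tokenizer._get_question_tokens("How old is Brad Pitt?")
+        >>> encoding = tokenizer.prepare_for_model(table, query, tokenized_table=table_tokens, query_tokens=query_tokens)
+        >>> sorted(encoding.keys())  # doctest: +SKIP
+        ['attention_mask', 'input_ids', 'token_type_ids']
+        ```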
+ """ + if isinstance(padding, bool): + if padding and (max_length is not None or pad_to_multiple_of is not None): + padding = PaddingStrategy.MAX_LENGTH + else: + padding = PaddingStrategy.DO_NOT_PAD + elif not isinstance(padding, PaddingStrategy): + padding = PaddingStrategy(padding) + + if isinstance(truncation, bool): + if truncation: + truncation = TapasTruncationStrategy.DROP_ROWS_TO_FIT + else: + truncation = TapasTruncationStrategy.DO_NOT_TRUNCATE + elif not isinstance(truncation, TapasTruncationStrategy): + truncation = TapasTruncationStrategy(truncation) + + encoded_inputs = {} + + is_part_of_batch = False + prev_answer_coordinates, prev_answer_text = None, None + if "prev_answer_coordinates" in kwargs and "prev_answer_text" in kwargs: + is_part_of_batch = True + prev_answer_coordinates = kwargs["prev_answer_coordinates"] + prev_answer_text = kwargs["prev_answer_text"] + + num_rows = self._get_num_rows(raw_table, truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE) + num_columns = self._get_num_columns(raw_table) + _, _, num_tokens = self._get_table_boundaries(tokenized_table) + + if truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE: + num_rows, num_tokens = self._get_truncated_table_rows( + query_tokens, tokenized_table, num_rows, num_columns, max_length, truncation_strategy=truncation + ) + table_data = list(self._get_table_values(tokenized_table, num_columns, num_rows, num_tokens)) + + query_ids = self.convert_tokens_to_ids(query_tokens) + table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data)) + table_ids = self.convert_tokens_to_ids(list(table_ids)) + + if "return_overflowing_tokens" in kwargs and kwargs["return_overflowing_tokens"]: + raise ValueError("TAPAS does not return overflowing tokens as it works on tables.") + + if add_special_tokens: + input_ids = self.build_inputs_with_special_tokens(query_ids, table_ids) + else: + input_ids = query_ids + table_ids + + if max_length is not None and len(input_ids) > max_length: + raise ValueError( + "Could not encode the query and table header given the maximum length. 
Encoding the query and table " + f"header results in a length of {len(input_ids)} which is higher than the max_length of {max_length}" + ) + + encoded_inputs["input_ids"] = input_ids + + segment_ids = self.create_segment_token_type_ids_from_sequences(query_ids, table_data) + column_ids = self.create_column_token_type_ids_from_sequences(query_ids, table_data) + row_ids = self.create_row_token_type_ids_from_sequences(query_ids, table_data) + if not is_part_of_batch or (prev_answer_coordinates is None and prev_answer_text is None): + # simply set the prev_labels to zeros + prev_labels = [0] * len(row_ids) + else: + prev_labels = self.get_answer_ids( + column_ids, row_ids, table_data, prev_answer_text, prev_answer_coordinates + ) + + # FIRST: parse both the table and question in terms of numeric values + + raw_table = add_numeric_table_values(raw_table) + raw_query = add_numeric_values_to_question(raw_query) + + # SECOND: add numeric-related features (and not parse them in these functions): + + column_ranks, inv_column_ranks = self._get_numeric_column_ranks(column_ids, row_ids, raw_table) + numeric_relations = self._get_numeric_relations(raw_query, column_ids, row_ids, raw_table) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_attention_mask: + attention_mask = self.create_attention_mask_from_sequences(query_ids, table_data) + encoded_inputs["attention_mask"] = attention_mask + + if answer_coordinates is not None and answer_text is not None: + labels = self.get_answer_ids(column_ids, row_ids, table_data, answer_text, answer_coordinates) + numeric_values = self._get_numeric_values(raw_table, column_ids, row_ids) + numeric_values_scale = self._get_numeric_values_scale(raw_table, column_ids, row_ids) + + encoded_inputs["labels"] = labels + encoded_inputs["numeric_values"] = numeric_values + encoded_inputs["numeric_values_scale"] = numeric_values_scale + + if return_token_type_ids: + token_type_ids = [ + segment_ids, + column_ids, + row_ids, + prev_labels, + column_ranks, + inv_column_ranks, + numeric_relations, + ] + + token_type_ids = [list(ids) for ids in list(zip(*token_type_ids))] + encoded_inputs["token_type_ids"] = token_type_ids + + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(query_ids, table_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(input_ids) + + # Check lengths + if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose: + if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + f"for this model ({len(encoded_inputs['input_ids'])} > {self.model_max_length}). Running this " + "sequence through the model will result in indexing errors." 
+                )
+                self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+
+        # Padding
+        if padding != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+            encoded_inputs = self.pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding=padding.value,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+
+        if return_length:
+            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+        batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    def _get_truncated_table_rows(
+        self,
+        query_tokens: List[str],
+        tokenized_table: TokenizedTable,
+        num_rows: int,
+        num_columns: int,
+        max_length: int,
+        truncation_strategy: Union[str, TapasTruncationStrategy],
+    ) -> Tuple[int, int]:
+        """
+        Truncates a sequence pair in-place following the strategy.
+
+        Args:
+            query_tokens (`List[str]`):
+                List of strings corresponding to the tokenized query.
+            tokenized_table (`TokenizedTable`):
+                Tokenized table.
+            num_rows (`int`):
+                Total number of table rows.
+            num_columns (`int`):
+                Total number of table columns.
+            max_length (`int`):
+                Total maximum length.
+            truncation_strategy (`str` or [`TapasTruncationStrategy`]):
+                Truncation strategy to use. Since this method should only be called when truncating, the only
+                available strategy is the `"drop_rows_to_fit"` strategy.
+
+        Returns:
+            `Tuple(int, int)`: tuple containing the number of rows after truncation, and the number of tokens available
+            for each table element.
+        """
+        if not isinstance(truncation_strategy, TapasTruncationStrategy):
+            truncation_strategy = TapasTruncationStrategy(truncation_strategy)
+
+        if max_length is None:
+            max_length = self.model_max_length
+
+        if truncation_strategy == TapasTruncationStrategy.DROP_ROWS_TO_FIT:
+            while True:
+                num_tokens = self._get_max_num_tokens(
+                    query_tokens, tokenized_table, num_rows=num_rows, num_columns=num_columns, max_length=max_length
+                )
+
+                if num_tokens is not None:
+                    # We could fit the table.
+                    break
+
+                # Try to drop a row to fit the table.
+                num_rows -= 1
+
+                if num_rows < 1:
+                    break
+        elif truncation_strategy != TapasTruncationStrategy.DO_NOT_TRUNCATE:
+            raise ValueError(f"Unknown truncation strategy {truncation_strategy}.")
+
+        return num_rows, num_tokens or 1
+
+    def _tokenize_table(
+        self,
+        table=None,
+    ):
+        """
+        Tokenizes column headers and cell texts of a table.
+
+        Args:
+            table (`pd.DataFrame`):
+                Table.
+
+        Returns:
+            `TokenizedTable`: TokenizedTable object.
+        """
+        tokenized_rows = []
+        tokenized_row = []
+        # tokenize column headers
+        for column in table:
+            if self.strip_column_names:
+                tokenized_row.append(self.tokenize(""))
+            else:
+                tokenized_row.append(self.tokenize(column))
+        tokenized_rows.append(tokenized_row)
+
+        # tokenize cell values
+        for idx, row in table.iterrows():
+            tokenized_row = []
+            for cell in row:
+                tokenized_row.append(self.tokenize(cell))
+            tokenized_rows.append(tokenized_row)
+
+        token_coordinates = []
+        for row_index, row in enumerate(tokenized_rows):
+            for column_index, cell in enumerate(row):
+                for token_index, _ in enumerate(cell):
+                    token_coordinates.append(
+                        TokenCoordinates(
+                            row_index=row_index,
+                            column_index=column_index,
+                            token_index=token_index,
+                        )
+                    )
+
+        return TokenizedTable(
+            rows=tokenized_rows,
+            selected_tokens=token_coordinates,
+        )
+
+    def _question_encoding_cost(self, question_tokens):
+        # Two extra spots for the [SEP] and [CLS] tokens.
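+        # For example, a question tokenized into 6 word pieces consumes 6 + 2 = 8
+        # positions of the sequence budget before any table tokens are added.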
+        return len(question_tokens) + 2
+
+    def _get_token_budget(self, question_tokens, max_length=None):
+        """
+        Computes the number of tokens left for the table after tokenizing a question, taking into account the max
+        sequence length of the model.
+
+        Args:
+            question_tokens (`List[str]`):
+                List of question tokens.
+
+        Returns:
+            `int`: the number of tokens left for the table, given the model max length.
+        """
+        return (max_length if max_length is not None else self.model_max_length) - self._question_encoding_cost(
+            question_tokens
+        )
+
+    def _get_table_values(self, table, num_columns, num_rows, num_tokens) -> Generator[TableValue, None, None]:
+        """Iterates over partial table and returns token, column and row indexes."""
+        for tc in table.selected_tokens:
+            # First row is header row.
+            if tc.row_index >= num_rows + 1:
+                continue
+            if tc.column_index >= num_columns:
+                continue
+            cell = table.rows[tc.row_index][tc.column_index]
+            token = cell[tc.token_index]
+            word_begin_index = tc.token_index
+            # Don't add partial words. Find the starting word piece and check if it
+            # fits in the token budget.
+            while word_begin_index >= 0 and _is_inner_wordpiece(cell[word_begin_index]):
+                word_begin_index -= 1
+            if word_begin_index >= num_tokens:
+                continue
+            yield TableValue(token, tc.column_index + 1, tc.row_index)
+
+    def _get_table_boundaries(self, table):
+        """Return maximal number of rows, columns and tokens."""
+        max_num_tokens = 0
+        max_num_columns = 0
+        max_num_rows = 0
+        for tc in table.selected_tokens:
+            max_num_columns = max(max_num_columns, tc.column_index + 1)
+            max_num_rows = max(max_num_rows, tc.row_index + 1)
+            max_num_tokens = max(max_num_tokens, tc.token_index + 1)
+        max_num_columns = min(self.max_column_id, max_num_columns)
+        max_num_rows = min(self.max_row_id, max_num_rows)
+        return max_num_rows, max_num_columns, max_num_tokens
+
+    def _get_table_cost(self, table, num_columns, num_rows, num_tokens):
+        return sum(1 for _ in self._get_table_values(table, num_columns, num_rows, num_tokens))
+
+    def _get_max_num_tokens(self, question_tokens, tokenized_table, num_columns, num_rows, max_length):
+        """Computes max number of tokens that can be squeezed into the budget."""
+        token_budget = self._get_token_budget(question_tokens, max_length)
+        _, _, max_num_tokens = self._get_table_boundaries(tokenized_table)
+        if self.cell_trim_length >= 0 and max_num_tokens > self.cell_trim_length:
+            max_num_tokens = self.cell_trim_length
+        num_tokens = 0
+        for num_tokens in range(max_num_tokens + 1):
+            cost = self._get_table_cost(tokenized_table, num_columns, num_rows, num_tokens + 1)
+            if cost > token_budget:
+                break
+        if num_tokens < max_num_tokens:
+            if self.cell_trim_length >= 0:
+                # We don't allow dynamic trimming if a cell_trim_length is set.
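+                # Returning None signals the caller (_get_truncated_table_rows) that the
+                # table does not fit with this cell size, so a row should be dropped instead.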
+ return None + if num_tokens == 0: + return None + return num_tokens + + def _get_num_columns(self, table): + num_columns = table.shape[1] + if num_columns >= self.max_column_id: + raise ValueError("Too many columns") + return num_columns + + def _get_num_rows(self, table, drop_rows_to_fit): + num_rows = table.shape[0] + if num_rows >= self.max_row_id: + if drop_rows_to_fit: + num_rows = self.max_row_id - 1 + else: + raise ValueError("Too many rows") + return num_rows + + def _serialize_text(self, question_tokens): + """Serializes texts in index arrays.""" + tokens = [] + segment_ids = [] + column_ids = [] + row_ids = [] + + # add [CLS] token at the beginning + tokens.append(self.cls_token) + segment_ids.append(0) + column_ids.append(0) + row_ids.append(0) + + for token in question_tokens: + tokens.append(token) + segment_ids.append(0) + column_ids.append(0) + row_ids.append(0) + + return tokens, segment_ids, column_ids, row_ids + + def _serialize( + self, + question_tokens, + table, + num_columns, + num_rows, + num_tokens, + ): + """Serializes table and text.""" + tokens, segment_ids, column_ids, row_ids = self._serialize_text(question_tokens) + + # add [SEP] token between question and table tokens + tokens.append(self.sep_token) + segment_ids.append(0) + column_ids.append(0) + row_ids.append(0) + + for token, column_id, row_id in self._get_table_values(table, num_columns, num_rows, num_tokens): + tokens.append(token) + segment_ids.append(1) + column_ids.append(column_id) + row_ids.append(row_id) + + return SerializedExample( + tokens=tokens, + segment_ids=segment_ids, + column_ids=column_ids, + row_ids=row_ids, + ) + + def _get_column_values(self, table, col_index): + table_numeric_values = {} + for row_index, row in table.iterrows(): + cell = row[col_index] + if cell.numeric_value is not None: + table_numeric_values[row_index] = cell.numeric_value + return table_numeric_values + + def _get_cell_token_indexes(self, column_ids, row_ids, column_id, row_id): + for index in range(len(column_ids)): + if column_ids[index] - 1 == column_id and row_ids[index] - 1 == row_id: + yield index + + def _get_numeric_column_ranks(self, column_ids, row_ids, table): + """Returns column ranks for all numeric columns.""" + + ranks = [0] * len(column_ids) + inv_ranks = [0] * len(column_ids) + + # original code from tf_example_utils.py of the original implementation + if table is not None: + for col_index in range(len(table.columns)): + table_numeric_values = self._get_column_values(table, col_index) + + if not table_numeric_values: + continue + + try: + key_fn = get_numeric_sort_key_fn(table_numeric_values.values()) + except ValueError: + continue + + table_numeric_values = {row_index: key_fn(value) for row_index, value in table_numeric_values.items()} + + table_numeric_values_inv = collections.defaultdict(list) + for row_index, value in table_numeric_values.items(): + table_numeric_values_inv[value].append(row_index) + + unique_values = sorted(table_numeric_values_inv.keys()) + + for rank, value in enumerate(unique_values): + for row_index in table_numeric_values_inv[value]: + for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index): + ranks[index] = rank + 1 + inv_ranks[index] = len(unique_values) - rank + + return ranks, inv_ranks + + def _get_numeric_sort_key_fn(self, table_numeric_values, value): + """ + Returns the sort key function for comparing value to table values. The function returned will be a suitable + input for the key param of the sort(). 
See number_annotation_utils._get_numeric_sort_key_fn for details + + Args: + table_numeric_values: Numeric values of a column + value: Numeric value in the question + + Returns: + A function key function to compare column and question values. + """ + if not table_numeric_values: + return None + all_values = list(table_numeric_values.values()) + all_values.append(value) + try: + return get_numeric_sort_key_fn(all_values) + except ValueError: + return None + + def _get_numeric_relations(self, question, column_ids, row_ids, table): + """ + Returns numeric relations embeddings + + Args: + question: Question object. + column_ids: Maps word piece position to column id. + row_ids: Maps word piece position to row id. + table: The table containing the numeric cell values. + """ + + numeric_relations = [0] * len(column_ids) + + # first, we add any numeric value spans to the question: + # Create a dictionary that maps a table cell to the set of all relations + # this cell has with any value in the question. + cell_indices_to_relations = collections.defaultdict(set) + if question is not None and table is not None: + for numeric_value_span in question.numeric_spans: + for value in numeric_value_span.values: + for column_index in range(len(table.columns)): + table_numeric_values = self._get_column_values(table, column_index) + sort_key_fn = self._get_numeric_sort_key_fn(table_numeric_values, value) + if sort_key_fn is None: + continue + for row_index, cell_value in table_numeric_values.items(): + relation = get_numeric_relation(value, cell_value, sort_key_fn) + if relation is not None: + cell_indices_to_relations[column_index, row_index].add(relation) + + # For each cell add a special feature for all its word pieces. + for (column_index, row_index), relations in cell_indices_to_relations.items(): + relation_set_index = 0 + for relation in relations: + assert relation.value >= Relation.EQ.value + relation_set_index += 2 ** (relation.value - Relation.EQ.value) + for cell_token_index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index): + numeric_relations[cell_token_index] = relation_set_index + + return numeric_relations + + def _get_numeric_values(self, table, column_ids, row_ids): + """Returns numeric values for computation of answer loss.""" + + numeric_values = [float("nan")] * len(column_ids) + + if table is not None: + num_rows = table.shape[0] + num_columns = table.shape[1] + + for col_index in range(num_columns): + for row_index in range(num_rows): + numeric_value = table.iloc[row_index, col_index].numeric_value + if numeric_value is not None: + if numeric_value.float_value is None: + continue + float_value = numeric_value.float_value + if float_value == float("inf"): + continue + for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index): + numeric_values[index] = float_value + + return numeric_values + + def _get_numeric_values_scale(self, table, column_ids, row_ids): + """Returns a scale to each token to down weigh the value of long words.""" + + numeric_values_scale = [1.0] * len(column_ids) + + if table is None: + return numeric_values_scale + + num_rows = table.shape[0] + num_columns = table.shape[1] + + for col_index in range(num_columns): + for row_index in range(num_rows): + indices = list(self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index)) + num_indices = len(indices) + if num_indices > 1: + for index in indices: + numeric_values_scale[index] = float(num_indices) + + return numeric_values_scale + + def 
_pad_to_seq_length(self, inputs):
+        while len(inputs) > self.model_max_length:
+            inputs.pop()
+        while len(inputs) < self.model_max_length:
+            inputs.append(0)
+
+    def _get_all_answer_ids_from_coordinates(
+        self,
+        column_ids,
+        row_ids,
+        answers_list,
+    ):
+        """Maps lists of answer coordinates to token indexes."""
+        answer_ids = [0] * len(column_ids)
+        found_answers = set()
+        all_answers = set()
+        for answers in answers_list:
+            column_index, row_index = answers
+            all_answers.add((column_index, row_index))
+            for index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
+                found_answers.add((column_index, row_index))
+                answer_ids[index] = 1
+
+        missing_count = len(all_answers) - len(found_answers)
+        return answer_ids, missing_count
+
+    def _get_all_answer_ids(self, column_ids, row_ids, answer_coordinates):
+        """
+        Maps answer coordinates of a question to token indexes.
+
+        In the SQA format (TSV), the coordinates are given as (row, column) tuples. Here, we first swap them to
+        (column, row) format before calling _get_all_answer_ids_from_coordinates.
+        """
+
+        def _to_coordinates(answer_coordinates_question):
+            return [(coords[1], coords[0]) for coords in answer_coordinates_question]
+
+        return self._get_all_answer_ids_from_coordinates(
+            column_ids, row_ids, answers_list=(_to_coordinates(answer_coordinates))
+        )
+
+    def _find_tokens(self, text, segment):
+        """Return start index of segment in text or None."""
+        logger.info(f"text: {text} {segment}")
+        for index in range(1 + len(text) - len(segment)):
+            for seg_index, seg_token in enumerate(segment):
+                if text[index + seg_index].piece != seg_token.piece:
+                    break
+            else:
+                return index
+        return None
+
+    def _find_answer_coordinates_from_answer_text(
+        self,
+        tokenized_table,
+        answer_text,
+    ):
+        """Returns all occurrences of answer_text in the table."""
+        logger.info(f"answer text: {answer_text}")
+        for row_index, row in enumerate(tokenized_table.rows):
+            if row_index == 0:
+                # We don't search for answers in the header.
+                continue
+            for col_index, cell in enumerate(row):
+                token_index = self._find_tokens(cell, answer_text)
+                if token_index is not None:
+                    yield TokenCoordinates(
+                        row_index=row_index,
+                        column_index=col_index,
+                        token_index=token_index,
+                    )
+
+    def _find_answer_ids_from_answer_texts(
+        self,
+        column_ids,
+        row_ids,
+        tokenized_table,
+        answer_texts,
+    ):
+        """Maps question with answer texts to the first matching token indexes."""
+        answer_ids = [0] * len(column_ids)
+        for answer_text in answer_texts:
+            for coordinates in self._find_answer_coordinates_from_answer_text(
+                tokenized_table,
+                answer_text,
+            ):
+                # Maps answer coordinates to indexes; this can fail if tokens / rows have
+                # been pruned.
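+                # The matched answer tokens must also form one contiguous span within the
+                # cell; otherwise the candidate match is discarded below.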
+ indexes = list( + self._get_cell_token_indexes( + column_ids, + row_ids, + column_id=coordinates.column_index, + row_id=coordinates.row_index - 1, + ) + ) + indexes.sort() + coordinate_answer_ids = [] + if indexes: + begin_index = coordinates.token_index + indexes[0] + end_index = begin_index + len(answer_text) + for index in indexes: + if index >= begin_index and index < end_index: + coordinate_answer_ids.append(index) + if len(coordinate_answer_ids) == len(answer_text): + for index in coordinate_answer_ids: + answer_ids[index] = 1 + break + return answer_ids + + def _get_answer_ids(self, column_ids, row_ids, answer_coordinates): + """Maps answer coordinates of a question to token indexes.""" + answer_ids, missing_count = self._get_all_answer_ids(column_ids, row_ids, answer_coordinates) + + if missing_count: + raise ValueError("Couldn't find all answers") + return answer_ids + + def get_answer_ids(self, column_ids, row_ids, tokenized_table, answer_texts_question, answer_coordinates_question): + if self.update_answer_coordinates: + return self._find_answer_ids_from_answer_texts( + column_ids, + row_ids, + tokenized_table, + answer_texts=[self.tokenize(at) for at in answer_texts_question], + ) + return self._get_answer_ids(column_ids, row_ids, answer_coordinates_question) + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = ( + padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length + ) + + # Initialize attention mask if not present. 
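+        # Before padding, every position corresponds to a real token, so the mask is
+        # all ones; the branches below extend it with zeros on the padded side.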
+ if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + padding_side = padding_side if padding_side is not None else self.padding_side + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [[self.pad_token_type_id] * 7] * difference + ) + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [0] * difference + if "numeric_values" in encoded_inputs: + encoded_inputs["numeric_values"] = encoded_inputs["numeric_values"] + [float("nan")] * difference + if "numeric_values_scale" in encoded_inputs: + encoded_inputs["numeric_values_scale"] = ( + encoded_inputs["numeric_values_scale"] + [1.0] * difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [[self.pad_token_type_id] * 7] * difference + encoded_inputs[ + "token_type_ids" + ] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [0] * difference + encoded_inputs["labels"] + if "numeric_values" in encoded_inputs: + encoded_inputs["numeric_values"] = [float("nan")] * difference + encoded_inputs["numeric_values"] + if "numeric_values_scale" in encoded_inputs: + encoded_inputs["numeric_values_scale"] = [1.0] * difference + encoded_inputs[ + "numeric_values_scale" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + return encoded_inputs + + # Everything related to converting logits to predictions + + def _get_cell_token_probs(self, probabilities, segment_ids, row_ids, column_ids): + for i, p in enumerate(probabilities): + segment_id = segment_ids[i] + col = column_ids[i] - 1 + row = row_ids[i] - 1 + if col >= 0 and row >= 0 and segment_id == 1: + yield i, p + + def _get_mean_cell_probs(self, probabilities, segment_ids, row_ids, column_ids): + """Computes average probability per cell, aggregating over tokens.""" + coords_to_probs = collections.defaultdict(list) + for i, prob in self._get_cell_token_probs(probabilities, segment_ids, row_ids, column_ids): + col = column_ids[i] - 1 + row = row_ids[i] - 1 + coords_to_probs[(col, row)].append(prob) + return {coords: np.array(cell_probs).mean() for coords, cell_probs in coords_to_probs.items()} + + def convert_logits_to_predictions(self, data, logits, logits_agg=None, cell_classification_threshold=0.5): + """ + Converts logits of [`TapasForQuestionAnswering`] to actual predicted answer coordinates and optional + aggregation indices. 
+
+        The original implementation, on which this function is based, can be found
+        [here](https://github.com/google-research/tapas/blob/4908213eb4df7aa988573350278b44c4dbe3f71b/tapas/experiments/prediction_utils.py#L288).
+
+        Args:
+            data (`dict`):
+                Dictionary mapping features to actual values. Should be created using [`TapasTokenizer`].
+            logits (`torch.Tensor` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
+                Tensor containing the logits at the token level.
+            logits_agg (`torch.Tensor` or `tf.Tensor` of shape `(batch_size, num_aggregation_labels)`, *optional*):
+                Tensor containing the aggregation logits.
+            cell_classification_threshold (`float`, *optional*, defaults to 0.5):
+                Threshold to be used for cell selection. All table cells for which their probability is larger than
+                this threshold will be selected.
+
+        Returns:
+            `tuple` comprising various elements depending on the inputs:
+
+            - predicted_answer_coordinates (`List[List[tuple]]` of length `batch_size`): Predicted answer coordinates
+              as a list of lists of tuples. Each element in the list contains the predicted answer coordinates of a
+              single example in the batch, as a list of tuples. Each tuple is a cell, i.e. (row index, column index).
+            - predicted_aggregation_indices (`List[int]` of length `batch_size`, *optional*, returned when
+              `logits_agg` is provided): Predicted aggregation operator indices of the aggregation head.
+        """
+        # converting to numpy arrays to work with PT/TF
+        logits = logits.numpy()
+        if logits_agg is not None:
+            logits_agg = logits_agg.numpy()
+        data = {key: value.numpy() for key, value in data.items() if key != "training"}
+        # input data is of type float32
+        # np.log(np.finfo(np.float32).max) = 88.72284
+        # Any value over 88.72284 will overflow when passed through the exponential, sending a warning
+        # We disable this warning by truncating the logits.
+        logits[logits < -88.7] = -88.7
+
+        # Compute probabilities from token logits
+        probabilities = 1 / (1 + np.exp(-logits)) * data["attention_mask"]
+        token_types = [
+            "segment_ids",
+            "column_ids",
+            "row_ids",
+            "prev_labels",
+            "column_ranks",
+            "inv_column_ranks",
+            "numeric_relations",
+        ]
+
+        # collect input_ids, segment ids, row ids and column ids of batch. Shape (batch_size, seq_len)
+        input_ids = data["input_ids"]
+        segment_ids = data["token_type_ids"][:, :, token_types.index("segment_ids")]
+        row_ids = data["token_type_ids"][:, :, token_types.index("row_ids")]
+        column_ids = data["token_type_ids"][:, :, token_types.index("column_ids")]
+
+        # next, get answer coordinates for every example in the batch
+        num_batch = input_ids.shape[0]
+        predicted_answer_coordinates = []
+        for i in range(num_batch):
+            probabilities_example = probabilities[i].tolist()
+            segment_ids_example = segment_ids[i]
+            row_ids_example = row_ids[i]
+            column_ids_example = column_ids[i]
+
+            max_width = column_ids_example.max()
+            max_height = row_ids_example.max()
+
+            if max_width == 0 and max_height == 0:
+                continue
+
+            cell_coords_to_prob = self._get_mean_cell_probs(
+                probabilities_example,
+                segment_ids_example.tolist(),
+                row_ids_example.tolist(),
+                column_ids_example.tolist(),
+            )
+
+            # Select the answers above the classification threshold.
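+            # A cell is selected when its mean token probability exceeds the threshold;
+            # coordinates are appended as (row, column) pairs, 0-indexed over the table body.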
+ answer_coordinates = [] + for col in range(max_width): + for row in range(max_height): + cell_prob = cell_coords_to_prob.get((col, row), None) + if cell_prob is not None: + if cell_prob > cell_classification_threshold: + answer_coordinates.append((row, col)) + answer_coordinates = sorted(answer_coordinates) + predicted_answer_coordinates.append(answer_coordinates) + + output = (predicted_answer_coordinates,) + + if logits_agg is not None: + predicted_aggregation_indices = logits_agg.argmax(axis=-1) + output = (predicted_answer_coordinates, predicted_aggregation_indices.tolist()) + + return output + + # End of everything related to converting logits to predictions + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. 
This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+
+        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+
+        Args:
+            text: A single token or whitespace separated tokens. This should have
+                already been passed through *BasicTokenizer*.
+
+        Returns:
+            A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+# Below: utilities for TAPAS tokenizer (independent from PyTorch/Tensorflow).
+# This includes functions to parse numeric values (dates and numbers) from both the table and questions in order
+# to create the column_ranks, inv_column_ranks, numeric_values, numeric values_scale and numeric_relations in
+# prepare_for_model of TapasTokenizer.
+# These are meant to be used in an academic setup; for production use cases, Gold mine or Aqua should be used.
+
+
+# taken from constants.py of the original implementation
+# URL: https://github.com/google-research/tapas/blob/master/tapas/utils/constants.py
+class Relation(enum.Enum):
+    HEADER_TO_CELL = 1  # Connects header to cell.
+    CELL_TO_HEADER = 2  # Connects cell to header.
+    QUERY_TO_HEADER = 3  # Connects query to headers.
+    QUERY_TO_CELL = 4  # Connects query to cells.
+    ROW_TO_CELL = 5  # Connects row to cells.
+    CELL_TO_ROW = 6  # Connects cells to row.
+    EQ = 7  # Annotation value is same as cell value
+    LT = 8  # Annotation value is less than cell value
+    GT = 9  # Annotation value is greater than cell value
+
+
+@dataclass
+class Date:
+    year: Optional[int] = None
+    month: Optional[int] = None
+    day: Optional[int] = None
+
+
+@dataclass
+class NumericValue:
+    float_value: Optional[float] = None
+    date: Optional[Date] = None
+
+
+@dataclass
+class NumericValueSpan:
+    begin_index: Optional[int] = None
+    end_index: Optional[int] = None
+    values: List[NumericValue] = None
+
+
+@dataclass
+class Cell:
+    text: str
+    numeric_value: Optional[NumericValue] = None
+
+
+@dataclass
+class Question:
+    original_text: str  # The original raw question string.
+    text: str  # The question string after normalization.
+    numeric_spans: Optional[List[NumericValueSpan]] = None
+
+
+# Below: all functions from number_utils.py as well as 2 functions (namely get_all_spans and normalize_for_match)
+# from text_utils.py of the original implementation. URLs:
+# - https://github.com/google-research/tapas/blob/master/tapas/utils/number_utils.py
+# - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py
+
+
+# Constants for parsing date expressions.
+# Masks that specify (by a bool) which of (year, month, day) will be populated.
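+# For instance, "August 2007" only populates (year, month): it is matched by the "%B %Y"
+# pattern below with the _YEAR_MONTH mask, and its day field stays None.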
+_DateMask = collections.namedtuple("_DateMask", ["year", "month", "day"])
+
+_YEAR = _DateMask(True, False, False)
+_YEAR_MONTH = _DateMask(True, True, False)
+_YEAR_MONTH_DAY = _DateMask(True, True, True)
+_MONTH = _DateMask(False, True, False)
+_MONTH_DAY = _DateMask(False, True, True)
+
+# Pairs of patterns to pass to 'datetime.strptime' and masks specifying which
+# fields will be set by the corresponding pattern.
+_DATE_PATTERNS = (
+    ("%B", _MONTH),
+    ("%Y", _YEAR),
+    ("%Ys", _YEAR),
+    ("%b %Y", _YEAR_MONTH),
+    ("%B %Y", _YEAR_MONTH),
+    ("%B %d", _MONTH_DAY),
+    ("%b %d", _MONTH_DAY),
+    ("%d %b", _MONTH_DAY),
+    ("%d %B", _MONTH_DAY),
+    ("%B %d, %Y", _YEAR_MONTH_DAY),
+    ("%d %B %Y", _YEAR_MONTH_DAY),
+    ("%m-%d-%Y", _YEAR_MONTH_DAY),
+    ("%Y-%m-%d", _YEAR_MONTH_DAY),
+    ("%Y-%m", _YEAR_MONTH),
+    ("%B %Y", _YEAR_MONTH),
+    ("%d %b %Y", _YEAR_MONTH_DAY),
+    ("%Y-%m-%d", _YEAR_MONTH_DAY),
+    ("%b %d, %Y", _YEAR_MONTH_DAY),
+    ("%d.%m.%Y", _YEAR_MONTH_DAY),
+    ("%A, %b %d", _MONTH_DAY),
+    ("%A, %B %d", _MONTH_DAY),
+)
+
+# This mapping is used to convert date patterns to regex patterns.
+_FIELD_TO_REGEX = (
+    ("%A", r"\w+"),  # Weekday as locale’s full name.
+    ("%B", r"\w+"),  # Month as locale’s full name.
+    ("%Y", r"\d{4}"),  # Year with century as a decimal number.
+    ("%b", r"\w{3}"),  # Month as locale’s abbreviated name.
+    ("%d", r"\d{1,2}"),  # Day of the month as a zero-padded decimal number.
+    ("%m", r"\d{1,2}"),  # Month as a zero-padded decimal number.
+)
+
+
+def _process_date_pattern(dp):
+    """Compute a regex for each date pattern to use as a prefilter."""
+    pattern, mask = dp
+    regex = pattern
+    regex = regex.replace(".", re.escape("."))
+    regex = regex.replace("-", re.escape("-"))
+    regex = regex.replace(" ", r"\s+")
+    for field, field_regex in _FIELD_TO_REGEX:
+        regex = regex.replace(field, field_regex)
+    # Make sure we didn't miss any of the fields.
+    assert "%" not in regex, regex
+    return pattern, mask, re.compile("^" + regex + "$")
+
+
+def _process_date_patterns():
+    return tuple(_process_date_pattern(dp) for dp in _DATE_PATTERNS)
+
+
+_PROCESSED_DATE_PATTERNS = _process_date_patterns()
+
+_MAX_DATE_NGRAM_SIZE = 5
+
+# Following DynSp:
+# https://github.com/Microsoft/DynSP/blob/master/util.py#L414.
+_NUMBER_WORDS = [
+    "zero",
+    "one",
+    "two",
+    "three",
+    "four",
+    "five",
+    "six",
+    "seven",
+    "eight",
+    "nine",
+    "ten",
+    "eleven",
+    "twelve",
+]
+
+_ORDINAL_WORDS = [
+    "zeroth",
+    "first",
+    "second",
+    "third",
+    "fourth",
+    "fifth",
+    "sixth",
+    "seventh",
+    "eighth",
+    "ninth",
+    "tenth",
+    "eleventh",
+    "twelfth",
+]
+
+_ORDINAL_SUFFIXES = ["st", "nd", "rd", "th"]
+
+_NUMBER_PATTERN = re.compile(r"((^|\s)[+-])?((\.\d+)|(\d+(,\d\d\d)*(\.\d*)?))")
+
+# Following DynSp:
+# https://github.com/Microsoft/DynSP/blob/master/util.py#L293.
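+# Years outside the [_MIN_YEAR, _MAX_YEAR] range below are rejected by
+# _get_numeric_value_from_date.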
+_MIN_YEAR = 1700 +_MAX_YEAR = 2016 + +_INF = float("INF") + + +def _get_numeric_value_from_date(date, mask): + """Converts date (datetime Python object) to a NumericValue object with a Date object value.""" + if date.year < _MIN_YEAR or date.year > _MAX_YEAR: + raise ValueError(f"Invalid year: {date.year}") + + new_date = Date() + if mask.year: + new_date.year = date.year + if mask.month: + new_date.month = date.month + if mask.day: + new_date.day = date.day + return NumericValue(date=new_date) + + +def _get_span_length_key(span): + """Sorts span by decreasing length first and increasing first index second.""" + return span[1] - span[0], -span[0] + + +def _get_numeric_value_from_float(value): + """Converts float (Python) to a NumericValue object with a float value.""" + return NumericValue(float_value=value) + + +# Doesn't parse ordinal expressions such as '18th of february 1655'. +def _parse_date(text): + """Attempts to format a text as a standard date string (yyyy-mm-dd).""" + text = re.sub(r"Sept\b", "Sep", text) + for in_pattern, mask, regex in _PROCESSED_DATE_PATTERNS: + if not regex.match(text): + continue + try: + date = datetime.datetime.strptime(text, in_pattern).date() + except ValueError: + continue + try: + return _get_numeric_value_from_date(date, mask) + except ValueError: + continue + return None + + +def _parse_number(text): + """Parses simple cardinal and ordinals numbers.""" + for suffix in _ORDINAL_SUFFIXES: + if text.endswith(suffix): + text = text[: -len(suffix)] + break + text = text.replace(",", "") + try: + value = float(text) + except ValueError: + return None + if math.isnan(value): + return None + if value == _INF: + return None + return value + + +def get_all_spans(text, max_ngram_length): + """ + Split a text into all possible ngrams up to 'max_ngram_length'. Split points are white space and punctuation. + + Args: + text: Text to split. + max_ngram_length: maximal ngram length. + Yields: + Spans, tuples of begin-end index. + """ + start_indexes = [] + for index, char in enumerate(text): + if not char.isalnum(): + continue + if index == 0 or not text[index - 1].isalnum(): + start_indexes.append(index) + if index + 1 == len(text) or not text[index + 1].isalnum(): + for start_index in start_indexes[-max_ngram_length:]: + yield start_index, index + 1 + + +def normalize_for_match(text): + return " ".join(text.lower().split()) + + +def format_text(text): + """Lowercases and strips punctuation.""" + text = text.lower().strip() + if text == "n/a" or text == "?" or text == "nan": + text = EMPTY_TEXT + + text = re.sub(r"[^\w\d]+", " ", text).replace("_", " ") + text = " ".join(text.split()) + text = text.strip() + if text: + return text + return EMPTY_TEXT + + +def parse_text(text): + """ + Extracts longest number and date spans. + + Args: + text: text to annotate + + Returns: + List of longest numeric value spans. 
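+
+    Example (an illustrative sketch; the indices follow from the parsing rules above):
+
+    ```python
+    >>> spans = parse_text("how many people lived in chicago in 2016?")
+    >>> [(span.begin_index, span.end_index) for span in spans]  # doctest: +SKIP
+    [(37, 41)]
+    ```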
+ """ + span_dict = collections.defaultdict(list) + for match in _NUMBER_PATTERN.finditer(text): + span_text = text[match.start() : match.end()] + number = _parse_number(span_text) + if number is not None: + span_dict[match.span()].append(_get_numeric_value_from_float(number)) + + for begin_index, end_index in get_all_spans(text, max_ngram_length=1): + if (begin_index, end_index) in span_dict: + continue + span_text = text[begin_index:end_index] + + number = _parse_number(span_text) + if number is not None: + span_dict[begin_index, end_index].append(_get_numeric_value_from_float(number)) + for number, word in enumerate(_NUMBER_WORDS): + if span_text == word: + span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number))) + break + for number, word in enumerate(_ORDINAL_WORDS): + if span_text == word: + span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number))) + break + + for begin_index, end_index in get_all_spans(text, max_ngram_length=_MAX_DATE_NGRAM_SIZE): + span_text = text[begin_index:end_index] + date = _parse_date(span_text) + if date is not None: + span_dict[begin_index, end_index].append(date) + + spans = sorted(span_dict.items(), key=lambda span_value: _get_span_length_key(span_value[0]), reverse=True) + selected_spans = [] + for span, value in spans: + for selected_span, _ in selected_spans: + if selected_span[0] <= span[0] and span[1] <= selected_span[1]: + break + else: + selected_spans.append((span, value)) + + selected_spans.sort(key=lambda span_value: span_value[0][0]) + + numeric_value_spans = [] + for span, values in selected_spans: + numeric_value_spans.append(NumericValueSpan(begin_index=span[0], end_index=span[1], values=values)) + return numeric_value_spans + + +# Below: all functions from number_annotation_utils.py and 2 functions (namely filter_invalid_unicode +# and filter_invalid_unicode_from_table) from text_utils.py of the original implementation. URL's: +# - https://github.com/google-research/tapas/blob/master/tapas/utils/number_annotation_utils.py +# - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py + + +_PrimitiveNumericValue = Union[float, Tuple[Optional[float], Optional[float], Optional[float]]] +_SortKeyFn = Callable[[NumericValue], Tuple[float, Ellipsis]] + +_DATE_TUPLE_SIZE = 3 + +EMPTY_TEXT = "EMPTY" + +NUMBER_TYPE = "number" +DATE_TYPE = "date" + + +def _get_value_type(numeric_value): + if numeric_value.float_value is not None: + return NUMBER_TYPE + elif numeric_value.date is not None: + return DATE_TYPE + raise ValueError(f"Unknown type: {numeric_value}") + + +def _get_value_as_primitive_value(numeric_value): + """Maps a NumericValue proto to a float or tuple of float.""" + if numeric_value.float_value is not None: + return numeric_value.float_value + if numeric_value.date is not None: + date = numeric_value.date + value_tuple = [None, None, None] + # All dates fields are cased to float to produce a simple primitive value. + if date.year is not None: + value_tuple[0] = float(date.year) + if date.month is not None: + value_tuple[1] = float(date.month) + if date.day is not None: + value_tuple[2] = float(date.day) + return tuple(value_tuple) + raise ValueError(f"Unknown type: {numeric_value}") + + +def _get_all_types(numeric_values): + return {_get_value_type(value) for value in numeric_values} + + +def get_numeric_sort_key_fn(numeric_values): + """ + Creates a function that can be used as a sort key or to compare the values. 
Maps to primitive types and finds the
+    biggest common subset. Consider the values "05/05/2010" and "August 2007", with the corresponding primitive values
+    (2010., 5., 5.) and (2007., 8., None). These values can be compared by year and month, so we map to the sequence
+    (2010., 5.), (2007., 8.). If we added a third value "2006" with primitive value (2006., None, None), we could only
+    compare by the year, so we would map to (2010.,), (2007.,) and (2006.,).
+
+    Args:
+        numeric_values: Values to compare
+
+    Returns:
+        A function that can be used as a sort key function (mapping numeric values to a comparable tuple)
+
+    Raises:
+        ValueError if values don't have a common type or are not comparable.
+    """
+    value_types = _get_all_types(numeric_values)
+    if len(value_types) != 1:
+        raise ValueError(f"No common value type in {numeric_values}")
+
+    value_type = next(iter(value_types))
+    if value_type == NUMBER_TYPE:
+        # Primitive values are simple floats, nothing to do here.
+        return _get_value_as_primitive_value
+
+    # The type can only be Date at this point which means the primitive type
+    # is a float triple.
+    valid_indexes = set(range(_DATE_TUPLE_SIZE))
+
+    for numeric_value in numeric_values:
+        value = _get_value_as_primitive_value(numeric_value)
+        assert isinstance(value, tuple)
+        for tuple_index, inner_value in enumerate(value):
+            if inner_value is None:
+                valid_indexes.discard(tuple_index)
+
+    if not valid_indexes:
+        raise ValueError(f"No common value in {numeric_values}")
+
+    def _sort_key_fn(numeric_value):
+        value = _get_value_as_primitive_value(numeric_value)
+        return tuple(value[index] for index in valid_indexes)
+
+    return _sort_key_fn
+
+
+def _consolidate_numeric_values(row_index_to_values, min_consolidation_fraction, debug_info):
+    """
+    Finds the most common numeric values in a column and returns them
+
+    Args:
+        row_index_to_values:
+            For each row index all the values in that cell.
+        min_consolidation_fraction:
+            Fraction of cells that need to have consolidated value.
+        debug_info:
+            Additional information only used for logging
+
+    Returns:
+        For each row index the first value that matches the most common value. Rows that don't have a matching value
+        are dropped. Empty dictionary if values can't be consolidated.
+    """
+    type_counts = collections.Counter()
+    for numeric_values in row_index_to_values.values():
+        type_counts.update(_get_all_types(numeric_values))
+    if not type_counts:
+        return {}
+    max_count = max(type_counts.values())
+    if max_count < len(row_index_to_values) * min_consolidation_fraction:
+        # logging.log_every_n(logging.INFO, f'Can\'t consolidate types: {debug_info} {row_index_to_values} {max_count}', 100)
+        return {}
+
+    valid_types = set()
+    for value_type, count in type_counts.items():
+        if count == max_count:
+            valid_types.add(value_type)
+    if len(valid_types) > 1:
+        assert DATE_TYPE in valid_types
+        max_type = DATE_TYPE
+    else:
+        max_type = next(iter(valid_types))
+
+    new_row_index_to_value = {}
+    for index, values in row_index_to_values.items():
+        # Extract the first matching value.
+        for value in values:
+            if _get_value_type(value) == max_type:
+                new_row_index_to_value[index] = value
+                break
+
+    return new_row_index_to_value
+
+
+def _get_numeric_values(text):
+    """Parses text and returns numeric values."""
+    numeric_spans = parse_text(text)
+    return itertools.chain(*(span.values for span in numeric_spans))
+
+
+def _get_column_values(table, col_index):
+    """
+    Parses text in column and returns a dict mapping row_index to values.
This is the _get_column_values function from + number_annotation_utils.py of the original implementation + + Args: + table: Pandas dataframe + col_index: integer, indicating the index of the column to get the numeric values of + """ + index_to_values = {} + for row_index, row in table.iterrows(): + text = normalize_for_match(row[col_index].text) + index_to_values[row_index] = list(_get_numeric_values(text)) + return index_to_values + + +def get_numeric_relation(value, other_value, sort_key_fn): + """Compares two values and returns their relation or None.""" + value = sort_key_fn(value) + other_value = sort_key_fn(other_value) + if value == other_value: + return Relation.EQ + if value < other_value: + return Relation.LT + if value > other_value: + return Relation.GT + return None + + +def add_numeric_values_to_question(question): + """Adds numeric value spans to a question.""" + original_text = question + question = normalize_for_match(question) + numeric_spans = parse_text(question) + return Question(original_text=original_text, text=question, numeric_spans=numeric_spans) + + +def filter_invalid_unicode(text): + """Return an empty string and True if 'text' is in invalid unicode.""" + return ("", True) if isinstance(text, bytes) else (text, False) + + +def filter_invalid_unicode_from_table(table): + """ + Removes invalid unicode from table. Checks whether a table cell text contains an invalid unicode encoding. If yes, + reset the table cell text to an empty str and log a warning for each invalid cell + + Args: + table: table to clean. + """ + # to do: add table id support + if not hasattr(table, "table_id"): + table.table_id = 0 + + for row_index, row in table.iterrows(): + for col_index, cell in enumerate(row): + cell, is_invalid = filter_invalid_unicode(cell) + if is_invalid: + logging.warning( + f"Scrub an invalid table body @ table_id: {table.table_id}, row_index: {row_index}, " + f"col_index: {col_index}", + ) + for col_index, column in enumerate(table.columns): + column, is_invalid = filter_invalid_unicode(column) + if is_invalid: + logging.warning(f"Scrub an invalid table header @ table_id: {table.table_id}, col_index: {col_index}") + + +def add_numeric_table_values(table, min_consolidation_fraction=0.7, debug_info=None): + """ + Parses text in table column-wise and adds the consolidated values. Consolidation refers to finding values with a + common types (date or number) + + Args: + table: + Table to annotate. + min_consolidation_fraction: + Fraction of cells in a column that need to have consolidated value. + debug_info: + Additional information used for logging. 
+ """ + table = table.copy() + # First, filter table on invalid unicode + filter_invalid_unicode_from_table(table) + + # Second, replace cell values by Cell objects + for row_index, row in table.iterrows(): + for col_index, cell in enumerate(row): + table.iloc[row_index, col_index] = Cell(text=cell) + + # Third, add numeric_value attributes to these Cell objects + for col_index, column in enumerate(table.columns): + column_values = _consolidate_numeric_values( + _get_column_values(table, col_index), + min_consolidation_fraction=min_consolidation_fraction, + debug_info=(debug_info, column), + ) + + for row_index, numeric_value in column_values.items(): + table.iloc[row_index, col_index].numeric_value = numeric_value + + return table + + +__all__ = ["TapasTokenizer"] diff --git a/docs/transformers/src/transformers/models/textnet/__init__.py b/docs/transformers/src/transformers/models/textnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f04a680b21f4895675b2001ef5972b11d7ab1d0 --- /dev/null +++ b/docs/transformers/src/transformers/models/textnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_textnet import * + from .image_processing_textnet import * + from .modeling_textnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/textnet/configuration_textnet.py b/docs/transformers/src/transformers/models/textnet/configuration_textnet.py new file mode 100644 index 0000000000000000000000000000000000000000..61ecaaeba8e554c526510e8169d0628e59c5422b --- /dev/null +++ b/docs/transformers/src/transformers/models/textnet/configuration_textnet.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2024 the Fast authors and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
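For orientation before the TextNet configuration body below: the TAPAS numeric-annotation helpers defined above are used together, with `add_numeric_table_values` wrapping every cell in a `Cell` and attaching a consolidated `numeric_value` column by column, and `add_numeric_values_to_question` annotating the query. A minimal sketch, assuming the helpers are importable from the tokenization module and using made-up table contents:

```python
import pandas as pd

from transformers.models.tapas.tokenization_tapas import (
    add_numeric_table_values,
    add_numeric_values_to_question,
)

# Made-up table; all cells must be strings.
table = pd.DataFrame({"Year": ["2005", "2007"], "Attendance": ["1523", "2989"]})

# Every cell becomes a Cell; columns where at least 70% of cells share a
# value type get a consolidated .numeric_value per row.
annotated = add_numeric_table_values(table, min_consolidation_fraction=0.7)
print(annotated.iloc[0, 0].numeric_value)  # NumericValue(float_value=2005.0, date=None)

# Questions are annotated the same way ("2006" is picked up as a numeric span).
question = add_numeric_values_to_question("what happened in 2006?")
print(question.numeric_spans)
```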
+"""TextNet model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging +from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class TextNetConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TextNextModel`]. It is used to instantiate a + TextNext model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [czczup/textnet-base](https://huggingface.co/czczup/textnet-base). Configuration objects inherit from + [`PretrainedConfig`] and can be used to control the model outputs.Read the documentation from [`PretrainedConfig`] + for more information. + + Args: + stem_kernel_size (`int`, *optional*, defaults to 3): + The kernel size for the initial convolution layer. + stem_stride (`int`, *optional*, defaults to 2): + The stride for the initial convolution layer. + stem_num_channels (`int`, *optional*, defaults to 3): + The num of channels in input for the initial convolution layer. + stem_out_channels (`int`, *optional*, defaults to 64): + The num of channels in out for the initial convolution layer. + stem_act_func (`str`, *optional*, defaults to `"relu"`): + The activation function for the initial convolution layer. + image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`): + The size (resolution) of each image. + conv_layer_kernel_sizes (`List[List[List[int]]]`, *optional*): + A list of stage-wise kernel sizes. If `None`, defaults to: + `[[[3, 3], [3, 3], [3, 3]], [[3, 3], [1, 3], [3, 3], [3, 1]], [[3, 3], [3, 3], [3, 1], [1, 3]], [[3, 3], [3, 1], [1, 3], [3, 3]]]`. + conv_layer_strides (`List[List[int]]`, *optional*): + A list of stage-wise strides. If `None`, defaults to: + `[[1, 2, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1]]`. + hidden_sizes (`List[int]`, *optional*, defaults to `[64, 64, 128, 256, 512]`): + Dimensionality (hidden size) at each stage. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. 
+ + Examples: + + ```python + >>> from transformers import TextNetConfig, TextNetBackbone + + >>> # Initializing a TextNetConfig + >>> configuration = TextNetConfig() + + >>> # Initializing a model (with random weights) + >>> model = TextNetBackbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "textnet" + + def __init__( + self, + stem_kernel_size=3, + stem_stride=2, + stem_num_channels=3, + stem_out_channels=64, + stem_act_func="relu", + image_size=[640, 640], + conv_layer_kernel_sizes=None, + conv_layer_strides=None, + hidden_sizes=[64, 64, 128, 256, 512], + batch_norm_eps=1e-5, + initializer_range=0.02, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + if conv_layer_kernel_sizes is None: + conv_layer_kernel_sizes = [ + [[3, 3], [3, 3], [3, 3]], + [[3, 3], [1, 3], [3, 3], [3, 1]], + [[3, 3], [3, 3], [3, 1], [1, 3]], + [[3, 3], [3, 1], [1, 3], [3, 3]], + ] + if conv_layer_strides is None: + conv_layer_strides = [[1, 2, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1]] + + self.stem_kernel_size = stem_kernel_size + self.stem_stride = stem_stride + self.stem_num_channels = stem_num_channels + self.stem_out_channels = stem_out_channels + self.stem_act_func = stem_act_func + + self.image_size = image_size + self.conv_layer_kernel_sizes = conv_layer_kernel_sizes + self.conv_layer_strides = conv_layer_strides + + self.initializer_range = initializer_range + self.hidden_sizes = hidden_sizes + self.batch_norm_eps = batch_norm_eps + + self.depths = [len(layer) for layer in self.conv_layer_kernel_sizes] + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, 5)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["TextNetConfig"] diff --git a/docs/transformers/src/transformers/models/textnet/convert_textnet_to_hf.py b/docs/transformers/src/transformers/models/textnet/convert_textnet_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a004d18a35e10bb83cc6a62e68b07fd0309100 --- /dev/null +++ b/docs/transformers/src/transformers/models/textnet/convert_textnet_to_hf.py @@ -0,0 +1,208 @@ +# coding=utf-8 +# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
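The `__init__` above derives `depths`, `stage_names`, and the backbone output selection from the stage-wise kernel lists. A minimal sketch of what the defaults produce:

```python
from transformers import TextNetConfig

config = TextNetConfig()
print(config.depths)        # [3, 4, 4, 4] -- one entry per stage, the length of each kernel list
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage4'] -- defaults to the last stage when unset
```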
+ +import argparse +import json +import logging +import re +from collections import OrderedDict + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import TextNetBackbone, TextNetConfig, TextNetImageProcessor + + +tiny_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_tiny.config" +small_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_small.config" +base_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_base.config" + +rename_key_mappings = { + "module.backbone": "textnet", + "first_conv": "stem", + "bn": "batch_norm", + "ver": "vertical", + "hor": "horizontal", +} + + +def prepare_config(size_config_url, size): + config_dict = json.loads(requests.get(size_config_url).text) + + backbone_config = {} + for stage_ix in range(1, 5): + stage_config = config_dict[f"stage{stage_ix}"] + + merged_dict = {} + + # Iterate through the list of dictionaries + for layer in stage_config: + for key, value in layer.items(): + if key != "name": + # Check if the key is already in the merged_dict + if key in merged_dict: + merged_dict[key].append(value) + else: + # If the key is not in merged_dict, create a new list with the value + merged_dict[key] = [value] + backbone_config[f"stage{stage_ix}"] = merged_dict + + neck_in_channels = [] + neck_out_channels = [] + neck_kernel_size = [] + neck_stride = [] + neck_dilation = [] + neck_groups = [] + + for i in range(1, 5): + layer_key = f"reduce_layer{i}" + layer_dict = config_dict["neck"].get(layer_key) + + if layer_dict: + # Append values to the corresponding lists + neck_in_channels.append(layer_dict["in_channels"]) + neck_out_channels.append(layer_dict["out_channels"]) + neck_kernel_size.append(layer_dict["kernel_size"]) + neck_stride.append(layer_dict["stride"]) + neck_dilation.append(layer_dict["dilation"]) + neck_groups.append(layer_dict["groups"]) + + textnet_config = TextNetConfig( + stem_kernel_size=config_dict["first_conv"]["kernel_size"], + stem_stride=config_dict["first_conv"]["stride"], + stem_num_channels=config_dict["first_conv"]["in_channels"], + stem_out_channels=config_dict["first_conv"]["out_channels"], + stem_act_func=config_dict["first_conv"]["act_func"], + conv_layer_kernel_sizes=[ + backbone_config["stage1"]["kernel_size"], + backbone_config["stage2"]["kernel_size"], + backbone_config["stage3"]["kernel_size"], + backbone_config["stage4"]["kernel_size"], + ], + conv_layer_strides=[ + backbone_config["stage1"]["stride"], + backbone_config["stage2"]["stride"], + backbone_config["stage3"]["stride"], + backbone_config["stage4"]["stride"], + ], + hidden_sizes=[ + config_dict["first_conv"]["out_channels"], + backbone_config["stage1"]["out_channels"][-1], + backbone_config["stage2"]["out_channels"][-1], + backbone_config["stage3"]["out_channels"][-1], + backbone_config["stage4"]["out_channels"][-1], + ], + out_features=["stage1", "stage2", "stage3", "stage4"], + out_indices=[1, 2, 3, 4], + ) + + return textnet_config + + +def convert_textnet_checkpoint(checkpoint_url, checkpoint_config_filename, pytorch_dump_folder_path): + config_filepath = hf_hub_download(repo_id="Raghavan/fast_model_config_files", filename="fast_model_configs.json") + + with open(config_filepath) as f: + content = json.loads(f.read()) + + size = content[checkpoint_config_filename]["short_size"] + + if "tiny" in content[checkpoint_config_filename]["config"]: + config = 
prepare_config(tiny_config_url, size) + expected_slice_backbone = torch.tensor( + [0.0000, 0.0000, 0.0000, 0.0000, 0.5300, 0.0000, 0.0000, 0.0000, 0.0000, 1.1221] + ) + elif "small" in content[checkpoint_config_filename]["config"]: + config = prepare_config(small_config_url, size) + expected_slice_backbone = torch.tensor( + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1394] + ) + else: + config = prepare_config(base_config_url, size) + expected_slice_backbone = torch.tensor( + [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000] + ) + + model = TextNetBackbone(config) + textnet_image_processor = TextNetImageProcessor(size={"shortest_edge": size}) + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["ema"] + state_dict_changed = OrderedDict() + for key in state_dict: + if "backbone" in key: + val = state_dict[key] + new_key = key + for search, replacement in rename_key_mappings.items(): + if search in new_key: + new_key = new_key.replace(search, replacement) + + pattern = r"textnet\.stage(\d)" + + def adjust_stage(match): + stage_number = int(match.group(1)) - 1 + return f"textnet.encoder.stages.{stage_number}.stage" + + # Using regex to find and replace the pattern in the string + new_key = re.sub(pattern, adjust_stage, new_key) + state_dict_changed[new_key] = val + model.load_state_dict(state_dict_changed) + model.eval() + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + + original_pixel_values = torch.tensor( + [0.1939, 0.3481, 0.4166, 0.3309, 0.4508, 0.4679, 0.4851, 0.4851, 0.3309, 0.4337] + ) + pixel_values = textnet_image_processor(image, return_tensors="pt").pixel_values + + assert torch.allclose(original_pixel_values, pixel_values[0][0][3][:10], atol=1e-4) + + with torch.no_grad(): + output = model(pixel_values) + + assert torch.allclose(output["feature_maps"][-1][0][10][12][:10].detach(), expected_slice_backbone, atol=1e-3) + + model.save_pretrained(pytorch_dump_folder_path) + textnet_image_processor.save_pretrained(pytorch_dump_folder_path) + logging.info("The converted weights are saved here : " + pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_url", + default="https://github.com/czczup/FAST/releases/download/release/fast_base_ic17mlt_640.pth", + type=str, + help="URL to the original PyTorch checkpoint (.pth file).", + ) + parser.add_argument( + "--checkpoint_config_filename", + default="fast_base_ic17mlt_640.py", + type=str, + help="URL to the original PyTorch checkpoint (.pth file).", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + args = parser.parse_args() + + convert_textnet_checkpoint( + args.checkpoint_url, + args.checkpoint_config_filename, + args.pytorch_dump_folder_path, + ) diff --git a/docs/transformers/src/transformers/models/textnet/image_processing_textnet.py b/docs/transformers/src/transformers/models/textnet/image_processing_textnet.py new file mode 100644 index 0000000000000000000000000000000000000000..74806a05566f2139a89acaceb3c57b34861b72cb --- /dev/null +++ b/docs/transformers/src/transformers/models/textnet/image_processing_textnet.py @@ -0,0 +1,355 @@ +# coding=utf-8 +# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for TextNet.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + +if is_vision_available(): + import PIL + + +class TextNetImageProcessor(BaseImageProcessor): + r""" + Constructs a TextNet image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 640}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + size_divisor (`int`, *optional*, defaults to 32): + Ensures height and width are rounded to a multiple of this value after resizing. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `False`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        size_divisor: int = 32,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_center_crop: bool = False,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = IMAGENET_DEFAULT_MEAN,
+        image_std: Optional[Union[float, List[float]]] = IMAGENET_DEFAULT_STD,
+        do_convert_rgb: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 640}
+        size = get_size_dict(size, default_to_square=False)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+        self.do_resize = do_resize
+        self.size = size
+        self.size_divisor = size_divisor
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.do_convert_rgb = do_convert_rgb
+
+        self._valid_processor_keys = [
+            "images",
+            "do_resize",
+            "size",
+            "size_divisor",
+            "resample",
+            "do_center_crop",
+            "crop_size",
+            "do_rescale",
+            "rescale_factor",
+            "do_normalize",
+            "image_mean",
+            "image_std",
+            "do_convert_rgb",
+            "return_tensors",
+            "data_format",
+            "input_data_format",
+        ]
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+        resized to keep the input aspect ratio. Both height and width are then rounded up to a multiple of
+        `size_divisor` (32 by default).
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            size_divisor (`int`, *optional*, defaults to `32`):
+                Ensures height and width are rounded to a multiple of this value after resizing.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+            default_to_square (`bool`, *optional*, defaults to `False`):
+                The value to be passed to `get_size_dict` as `default_to_square` when computing the image size.
If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or + not.Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. + """ + if "shortest_edge" in size: + size = size["shortest_edge"] + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + height, width = get_resize_output_image_size( + image, size=size, input_data_format=input_data_format, default_to_square=False + ) + if height % self.size_divisor != 0: + height += self.size_divisor - (height % self.size_divisor) + if width % self.size_divisor != 0: + width += self.size_divisor - (width % self.size_divisor) + + return resize( + image, + size=(height, width), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + size_divisor: Optional[int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + size_divisor (`int`, *optional*, defaults to `32`): + Ensures height and width are rounded to a multiple of this value after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + size_divisor = size_divisor if size_divisor is not None else self.size_divisor + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. 
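+        # (the scaled-image check and the channel-format inference below both
+        # operate on numpy arrays, which is why the conversion happens first)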
+ images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["TextNetImageProcessor"] diff --git a/docs/transformers/src/transformers/models/textnet/modeling_textnet.py b/docs/transformers/src/transformers/models/textnet/modeling_textnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d32a7351af893fe90d4f36264390186ac7658975 --- /dev/null +++ b/docs/transformers/src/transformers/models/textnet/modeling_textnet.py @@ -0,0 +1,487 @@ +# coding=utf-8 +# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
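Before the modeling code: the `resize` method above matches the shortest edge first, then rounds each output dimension up to a multiple of `size_divisor`. A minimal sketch of that rounding step (the helper name is illustrative, not part of the API):

```python
def round_up_to_divisor(value: int, divisor: int = 32) -> int:
    # Mirrors the rounding in TextNetImageProcessor.resize: bump the dimension
    # up to the next multiple of `divisor` unless it already is one.
    if value % divisor != 0:
        value += divisor - (value % divisor)
    return value


assert round_up_to_divisor(640) == 640
assert round_up_to_divisor(653) == 672  # 653 is bumped to the next multiple of 32
```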
+"""PyTorch TextNet model.""" + +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers import PreTrainedModel, add_start_docstrings +from transformers.activations import ACT2CLS +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from transformers.models.textnet.configuration_textnet import TextNetConfig +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.backbone_utils import BackboneMixin + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "TextNetConfig" +_CHECKPOINT_FOR_DOC = "czczup/textnet-base" +_EXPECTED_OUTPUT_SHAPE = [1, 512, 20, 27] + + +class TextNetConvLayer(nn.Module): + def __init__(self, config: TextNetConfig): + super().__init__() + + self.kernel_size = config.stem_kernel_size + self.stride = config.stem_stride + self.activation_function = config.stem_act_func + + padding = ( + (config.kernel_size[0] // 2, config.kernel_size[1] // 2) + if isinstance(config.stem_kernel_size, tuple) + else config.stem_kernel_size // 2 + ) + + self.conv = nn.Conv2d( + config.stem_num_channels, + config.stem_out_channels, + kernel_size=config.stem_kernel_size, + stride=config.stem_stride, + padding=padding, + bias=False, + ) + self.batch_norm = nn.BatchNorm2d(config.stem_out_channels, config.batch_norm_eps) + + self.activation = nn.Identity() + if self.activation_function is not None: + self.activation = ACT2CLS[self.activation_function]() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + return self.activation(hidden_states) + + +class TextNetRepConvLayer(nn.Module): + r""" + This layer supports re-parameterization by combining multiple convolutional branches + (e.g., main convolution, vertical, horizontal, and identity branches) during training. + At inference time, these branches can be collapsed into a single convolution for + efficiency, as per the re-parameterization paradigm. + + The "Rep" in the name stands for "re-parameterization" (introduced by RepVGG). 
+ """ + + def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, kernel_size: int, stride: int): + super().__init__() + + self.num_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + + self.activation_function = nn.ReLU() + + self.main_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + self.main_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps) + + vertical_padding = ((kernel_size[0] - 1) // 2, 0) + horizontal_padding = (0, (kernel_size[1] - 1) // 2) + + if kernel_size[1] != 1: + self.vertical_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_size[0], 1), + stride=stride, + padding=vertical_padding, + bias=False, + ) + self.vertical_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps) + else: + self.vertical_conv, self.vertical_batch_norm = None, None + + if kernel_size[0] != 1: + self.horizontal_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, kernel_size[1]), + stride=stride, + padding=horizontal_padding, + bias=False, + ) + self.horizontal_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps) + else: + self.horizontal_conv, self.horizontal_batch_norm = None, None + + self.rbr_identity = ( + nn.BatchNorm2d(num_features=in_channels, eps=config.batch_norm_eps) + if out_channels == in_channels and stride == 1 + else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + main_outputs = self.main_conv(hidden_states) + main_outputs = self.main_batch_norm(main_outputs) + + # applies a convolution with a vertical kernel + if self.vertical_conv is not None: + vertical_outputs = self.vertical_conv(hidden_states) + vertical_outputs = self.vertical_batch_norm(vertical_outputs) + main_outputs = main_outputs + vertical_outputs + + # applies a convolution with a horizontal kernel + if self.horizontal_conv is not None: + horizontal_outputs = self.horizontal_conv(hidden_states) + horizontal_outputs = self.horizontal_batch_norm(horizontal_outputs) + main_outputs = main_outputs + horizontal_outputs + + if self.rbr_identity is not None: + id_out = self.rbr_identity(hidden_states) + main_outputs = main_outputs + id_out + + return self.activation_function(main_outputs) + + +class TextNetStage(nn.Module): + def __init__(self, config: TextNetConfig, depth: int): + super().__init__() + kernel_size = config.conv_layer_kernel_sizes[depth] + stride = config.conv_layer_strides[depth] + + num_layers = len(kernel_size) + stage_in_channel_size = config.hidden_sizes[depth] + stage_out_channel_size = config.hidden_sizes[depth + 1] + + in_channels = [stage_in_channel_size] + [stage_out_channel_size] * (num_layers - 1) + out_channels = [stage_out_channel_size] * num_layers + + stage = [] + for stage_config in zip(in_channels, out_channels, kernel_size, stride): + stage.append(TextNetRepConvLayer(config, *stage_config)) + self.stage = nn.ModuleList(stage) + + def forward(self, hidden_state): + for block in self.stage: + hidden_state = block(hidden_state) + return hidden_state + + +class TextNetEncoder(nn.Module): + def __init__(self, config: TextNetConfig): + super().__init__() + + stages = [] + num_stages = len(config.conv_layer_kernel_sizes) + for stage_ix in 
range(num_stages): + stages.append(TextNetStage(config, stage_ix)) + + self.stages = nn.ModuleList(stages) + + def forward( + self, + hidden_state: torch.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BaseModelOutputWithNoAttention: + hidden_states = [hidden_state] + for stage in self.stages: + hidden_state = stage(hidden_state) + hidden_states.append(hidden_state) + + if not return_dict: + output = (hidden_state,) + return output + (hidden_states,) if output_hidden_states else output + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + +TEXTNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`TextNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TEXTNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`TextNetImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class TextNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = TextNetConfig + base_model_prefix = "textnet" + main_input_name = "pixel_values" + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.BatchNorm2d): + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() + + +@add_start_docstrings( + "The bare Textnet model outputting raw features without any specific head on top.", + TEXTNET_START_DOCSTRING, +) +class TextNetModel(TextNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.stem = TextNetConvLayer(config) + self.encoder = TextNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((2, 2)) + self.post_init() + + @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> Union[Tuple[Any, List[Any]], Tuple[Any], BaseModelOutputWithPoolingAndNoAttention]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_state = self.stem(pixel_values) + + encoder_outputs = self.encoder( + hidden_state, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + output = (last_hidden_state, pooled_output) + return output + (encoder_outputs[1],) if output_hidden_states else output + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs[1] if output_hidden_states else None, + ) + + +@add_start_docstrings( + """ + TextNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + TEXTNET_START_DOCSTRING, +) +class TextNetForImageClassification(TextNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.textnet = TextNetModel(config) + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.flatten = nn.Flatten() + self.fc = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + + # classification head + self.classifier = nn.ModuleList([self.avg_pool, self.flatten]) + + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + ```python + >>> import torch + >>> import requests + >>> from transformers import TextNetForImageClassification, TextNetImageProcessor + >>> from PIL import Image + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base") + >>> model = TextNetForImageClassification.from_pretrained("czczup/textnet-base") + + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> outputs.logits.shape + torch.Size([1, 2]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.textnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + last_hidden_state = outputs[0] + for layer in self.classifier: + last_hidden_state = layer(last_hidden_state) + logits = self.fc(last_hidden_state) + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + TextNet backbone, to be used with frameworks like DETR and MaskFormer. 
+    """,
+    TEXTNET_START_DOCSTRING,
+)
+class TextNetBackbone(TextNetPreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        self.textnet = TextNetModel(config)
+        self.num_features = config.hidden_sizes
+
+        # initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+    ) -> Union[Tuple[Tuple], BackboneOutput]:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> import requests
+        >>> from PIL import Image
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
+        >>> model = AutoBackbone.from_pretrained("czczup/textnet-base")
+
+        >>> inputs = processor(image, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs = self.textnet(pixel_values, output_hidden_states=True, return_dict=return_dict)
+
+        hidden_states = outputs.hidden_states if return_dict else outputs[2]
+
+        feature_maps = ()
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)
+
+        if not return_dict:
+            output = (feature_maps,)
+            if output_hidden_states:
+                hidden_states = outputs.hidden_states if return_dict else outputs[2]
+                output += (hidden_states,)
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=None,
+        )
+
+
+__all__ = ["TextNetBackbone", "TextNetModel", "TextNetPreTrainedModel", "TextNetForImageClassification"]
diff --git a/docs/transformers/src/transformers/models/time_series_transformer/__init__.py b/docs/transformers/src/transformers/models/time_series_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e36c36bcfccb32d5299b30ab6c48283f93e35e7a
--- /dev/null
+++ b/docs/transformers/src/transformers/models/time_series_transformer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
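Rounding off the TextNet modeling file: `TextNetBackbone.forward` keeps exactly the hidden states whose stage names appear in `out_features`. A minimal sketch with randomly initialized weights (illustrative only, no pretrained checkpoint needed):

```python
import torch

from transformers import TextNetBackbone, TextNetConfig

# Which hidden states land in `feature_maps` is controlled entirely by
# `out_features` (or, equivalently, `out_indices`) on the config.
config = TextNetConfig(out_features=["stage2", "stage4"])
model = TextNetBackbone(config)
model.eval()

pixel_values = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    outputs = model(pixel_values)

print(len(outputs.feature_maps))  # 2 -- one feature map per requested stage
```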
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_time_series_transformer import * + from .modeling_time_series_transformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/time_series_transformer/configuration_time_series_transformer.py b/docs/transformers/src/transformers/models/time_series_transformer/configuration_time_series_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc063774389cbad89b0ff7fd3e496674f2942b6c --- /dev/null +++ b/docs/transformers/src/transformers/models/time_series_transformer/configuration_time_series_transformer.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Time Series Transformer model configuration""" + +from typing import List, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimeSeriesTransformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TimeSeriesTransformerModel`]. It is used to + instantiate a Time Series Transformer model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Time Series + Transformer + [huggingface/time-series-transformer-tourism-monthly](https://huggingface.co/huggingface/time-series-transformer-tourism-monthly) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + prediction_length (`int`): + The prediction length for the decoder. In other words, the prediction horizon of the model. This value is + typically dictated by the dataset and we recommend to set it appropriately. + context_length (`int`, *optional*, defaults to `prediction_length`): + The context length for the encoder. If `None`, the context length will be the same as the + `prediction_length`. + distribution_output (`string`, *optional*, defaults to `"student_t"`): + The distribution emission head for the model. Could be either "student_t", "normal" or "negative_binomial". + loss (`string`, *optional*, defaults to `"nll"`): + The loss function for the model corresponding to the `distribution_output` head. For parametric + distributions it is the negative log likelihood (nll) - which currently is the only supported one. + input_size (`int`, *optional*, defaults to 1): + The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of + multivariate targets. 
+ scaling (`string` or `bool`, *optional* defaults to `"mean"`): + Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the + scaler is set to "mean". + lags_sequence (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 5, 6, 7]`): + The lags of the input time series as covariates often dictated by the frequency of the data. Default is + `[1, 2, 3, 4, 5, 6, 7]` but we recommend to change it based on the dataset appropriately. + num_time_features (`int`, *optional*, defaults to 0): + The number of time features in the input time series. + num_dynamic_real_features (`int`, *optional*, defaults to 0): + The number of dynamic real valued features. + num_static_categorical_features (`int`, *optional*, defaults to 0): + The number of static categorical features. + num_static_real_features (`int`, *optional*, defaults to 0): + The number of static real valued features. + cardinality (`list[int]`, *optional*): + The cardinality (number of different values) for each of the static categorical features. Should be a list + of integers, having the same length as `num_static_categorical_features`. Cannot be `None` if + `num_static_categorical_features` is > 0. + embedding_dimension (`list[int]`, *optional*): + The dimension of the embedding for each of the static categorical features. Should be a list of integers, + having the same length as `num_static_categorical_features`. Cannot be `None` if + `num_static_categorical_features` is > 0. + d_model (`int`, *optional*, defaults to 64): + Dimensionality of the transformer layers. + encoder_layers (`int`, *optional*, defaults to 2): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 2): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 32): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + decoder_ffn_dim (`int`, *optional*, defaults to 32): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and decoder. If string, `"gelu"` and + `"relu"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the encoder, and decoder. + encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention and fully connected layers for each encoder layer. + decoder_layerdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention and fully connected layers for each decoder layer. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability used between the two layers of the feed-forward networks. + num_parallel_samples (`int`, *optional*, defaults to 100): + The number of samples to generate in parallel for each time step of inference. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal weight initialization distribution. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether to use the past key/values attentions (if applicable to the model) to speed up decoding. + + Example: + + ```python + >>> from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerModel + + >>> # Initializing a Time Series Transformer configuration with 12 time steps for prediction + >>> configuration = TimeSeriesTransformerConfig(prediction_length=12) + + >>> # Randomly initializing a model (with random weights) from the configuration + >>> model = TimeSeriesTransformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "time_series_transformer" + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "num_hidden_layers": "encoder_layers", + } + + def __init__( + self, + prediction_length: Optional[int] = None, + context_length: Optional[int] = None, + distribution_output: str = "student_t", + loss: str = "nll", + input_size: int = 1, + lags_sequence: List[int] = [1, 2, 3, 4, 5, 6, 7], + scaling: Optional[Union[str, bool]] = "mean", + num_dynamic_real_features: int = 0, + num_static_categorical_features: int = 0, + num_static_real_features: int = 0, + num_time_features: int = 0, + cardinality: Optional[List[int]] = None, + embedding_dimension: Optional[List[int]] = None, + encoder_ffn_dim: int = 32, + decoder_ffn_dim: int = 32, + encoder_attention_heads: int = 2, + decoder_attention_heads: int = 2, + encoder_layers: int = 2, + decoder_layers: int = 2, + is_encoder_decoder: bool = True, + activation_function: str = "gelu", + d_model: int = 64, + dropout: float = 0.1, + encoder_layerdrop: float = 0.1, + decoder_layerdrop: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + num_parallel_samples: int = 100, + init_std: float = 0.02, + use_cache=True, + **kwargs, + ): + # time series specific configuration + self.prediction_length = prediction_length + self.context_length = context_length or prediction_length + self.distribution_output = distribution_output + self.loss = loss + self.input_size = input_size + self.num_time_features = num_time_features + self.lags_sequence = lags_sequence + self.scaling = scaling + self.num_dynamic_real_features = num_dynamic_real_features + self.num_static_real_features = num_static_real_features + self.num_static_categorical_features = num_static_categorical_features + if cardinality and num_static_categorical_features > 0: + if len(cardinality) != num_static_categorical_features: + raise ValueError( + "The cardinality should be a list of the same length as `num_static_categorical_features`" + ) + self.cardinality = cardinality + else: + self.cardinality = [0] + if embedding_dimension and num_static_categorical_features > 0: + if len(embedding_dimension) != num_static_categorical_features: + raise ValueError( + "The embedding dimension should be a list of the same length as `num_static_categorical_features`" + ) + self.embedding_dimension = embedding_dimension + else: + self.embedding_dimension = [min(50, (cat + 1) // 2) for cat in self.cardinality] + self.num_parallel_samples = num_parallel_samples + + # Transformer architecture configuration + self.feature_size = input_size * len(lags_sequence) + self._number_of_features + self.d_model = d_model + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.decoder_ffn_dim = 
decoder_ffn_dim + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + + self.activation_function = activation_function + self.init_std = init_std + + self.use_cache = use_cache + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def _number_of_features(self) -> int: + return ( + sum(self.embedding_dimension) + + self.num_dynamic_real_features + + self.num_time_features + + self.num_static_real_features + + self.input_size * 2 # the log1p(abs(loc)) and log(scale) features + ) + + +__all__ = ["TimeSeriesTransformerConfig"] diff --git a/docs/transformers/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/docs/transformers/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d63be6d8d7567d5080b002600c767111d2f976e2 --- /dev/null +++ b/docs/transformers/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -0,0 +1,1781 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Time Series Transformer model.""" + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + SampleTSPredictionOutput, + Seq2SeqTSModelOutput, + Seq2SeqTSPredictionOutput, +) +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_time_series_transformer import TimeSeriesTransformerConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TimeSeriesTransformerConfig" + + +class TimeSeriesFeatureEmbedder(nn.Module): + """ + Embed a sequence of categorical features. + + Args: + cardinalities (`list[int]`): + List of cardinalities of the categorical features. + embedding_dims (`list[int]`): + List of embedding dimensions of the categorical features. 
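+
+    Example (a minimal, illustrative sketch; the cardinalities, embedding sizes and batch shape below are made up):
+
+    ```python
+    >>> import torch
+
+    >>> # two categorical features with 10 and 20 possible values, embedded into 3 and 5 dimensions
+    >>> embedder = TimeSeriesFeatureEmbedder(cardinalities=[10, 20], embedding_dims=[3, 5])
+    >>> features = torch.randint(0, 10, (4, 2))  # batch of 4, one column per categorical feature
+    >>> embedder(features).shape
+    torch.Size([4, 8])
+    ```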
+ """ + + def __init__(self, cardinalities: List[int], embedding_dims: List[int]) -> None: + super().__init__() + + self.num_features = len(cardinalities) + self.embedders = nn.ModuleList([nn.Embedding(c, d) for c, d in zip(cardinalities, embedding_dims)]) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + if self.num_features > 1: + # we slice the last dimension, giving an array of length + # self.num_features with shape (N,T) or (N) + cat_feature_slices = torch.chunk(features, self.num_features, dim=-1) + else: + cat_feature_slices = [features] + + return torch.cat( + [ + embed(cat_feature_slice.squeeze(-1)) + for embed, cat_feature_slice in zip(self.embedders, cat_feature_slices) + ], + dim=-1, + ) + + +class TimeSeriesStdScaler(nn.Module): + """ + Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by + subtracting from the mean and dividing by the standard deviation. + """ + + def __init__(self, config: TimeSeriesTransformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5 + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. + Returns: + tuple of `torch.Tensor` of shapes + (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`, + `(batch_size, 1, num_input_channels)`) + """ + denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) + denominator = denominator.clamp_min(1.0) + loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator + + variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator + scale = torch.sqrt(variance + self.minimum_scale) + return (data - loc) / scale, loc, scale + + +class TimeSeriesMeanScaler(nn.Module): + """ + Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data + accordingly. + """ + + def __init__(self, config: TimeSeriesTransformerConfig): + super().__init__() + self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1 + self.keepdim = config.keepdim if hasattr(config, "keepdim") else True + self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10 + self.default_scale = config.default_scale if hasattr(config, "default_scale") else None + + def forward( + self, data: torch.Tensor, observed_indicator: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`): + input for Batch norm calculation + observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`): + Calculating the scale on the observed indicator. 
+        Returns:
+            tuple of `torch.Tensor` of shapes
+                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+                `(batch_size, 1, num_input_channels)`)
+        """
+        ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
+        num_observed = observed_indicator.sum(self.dim, keepdim=True)
+
+        scale = ts_sum / torch.clamp(num_observed, min=1)
+
+        # If `default_scale` is provided, we use it, otherwise we use the scale
+        # of the batch.
+        if self.default_scale is None:
+            batch_sum = ts_sum.sum(dim=0)
+            batch_observations = torch.clamp(num_observed.sum(0), min=1)
+            default_scale = torch.squeeze(batch_sum / batch_observations)
+        else:
+            default_scale = self.default_scale * torch.ones_like(scale)
+
+        # apply default scale where there are no observations
+        scale = torch.where(num_observed > 0, scale, default_scale)
+
+        # ensure the scale is at least `self.minimum_scale`
+        scale = torch.clamp(scale, min=self.minimum_scale)
+        scaled_data = data / scale
+
+        if not self.keepdim:
+            scale = scale.squeeze(dim=self.dim)
+
+        return scaled_data, torch.zeros_like(scale), scale
+
+
+class TimeSeriesNOPScaler(nn.Module):
+    """
+    Assigns a scaling factor equal to 1 along the scaling dimension (by default the time dimension), and therefore
+    applies no scaling to the input data.
+    """
+
+    def __init__(self, config: TimeSeriesTransformerConfig):
+        super().__init__()
+        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
+        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
+
+    def forward(
+        self, data: torch.Tensor, observed_indicator: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Parameters:
+            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
+                input whose scale statistics are computed (the data itself is returned unchanged)
+        Returns:
+            tuple of `torch.Tensor` of shapes
+                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
+                `(batch_size, 1, num_input_channels)`)
+        """
+        scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
+        loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
+        return data, loc, scale
+
+
+def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the negative log likelihood loss from input distribution with respect to target.
+    """
+    return -input.log_prob(target)
+
+
+def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
+    """
+    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
+    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
+
+    Args:
+        input_tensor (`torch.FloatTensor`):
+            Input tensor, of which the average must be computed.
+        weights (`torch.FloatTensor`, *optional*):
+            Weights tensor, of the same shape as `input_tensor`.
+        dim (`int`, *optional*):
+            The dim along which to average `input_tensor`.
+
+    Returns:
+        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
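+
+    Example (a minimal sketch of the masking behaviour):
+
+    ```python
+    >>> import torch
+
+    >>> values = torch.tensor([1.0, 2.0, 3.0])
+    >>> weights = torch.tensor([1.0, 0.0, 1.0])  # mask out the middle value
+    >>> weighted_average(values, weights=weights)
+    tensor(2.)
+    ```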
+ """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->TimeSeries +class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + + def _init_weight(self): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = self.weight.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + self.weight = nn.Parameter(out, requires_grad=False) + + @torch.no_grad() + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions) + + +class TimeSeriesValueEmbedding(nn.Module): + def __init__(self, feature_size, d_model): + super().__init__() + self.value_projection = nn.Linear(in_features=feature_size, out_features=d_model, bias=False) + + def forward(self, x): + return self.value_projection(x) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->TimeSeriesTransformer +class TimeSeriesTransformerAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[TimeSeriesTransformerConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
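+        # `attn_output` has shape (bsz, tgt_len, num_heads, head_dim) after the transpose above; the reshape
+        # below merges the heads back into a single `embed_dim`-sized channel dimension.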
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER +class TimeSeriesTransformerEncoderLayer(nn.Module): + def __init__(self, config: TimeSeriesTransformerConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# TODO: Implement attention with SDPA for TimeSeriesTransformer. 
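+# Only the eager (manual) implementation above is registered for now, so `config._attn_implementation`
+# is expected to resolve to "eager" when the encoder/decoder layers look up this mapping.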
+TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {
+    "eager": TimeSeriesTransformerAttention,
+}
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
+class TimeSeriesTransformerDecoderLayer(nn.Module):
+    def __init__(self, config: TimeSeriesTransformerConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(decoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
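+
+        Example (a shape-only sketch with a randomly initialized, internal layer, meant to illustrate the output and
+        cache layout rather than real usage):
+
+        ```python
+        >>> import torch
+        >>> from transformers import TimeSeriesTransformerConfig
+
+        >>> config = TimeSeriesTransformerConfig(prediction_length=12)
+        >>> layer = TimeSeriesTransformerDecoderLayer(config)
+        >>> hidden_states = torch.randn(1, 5, config.d_model)
+        >>> outputs = layer(hidden_states, use_cache=True)
+        >>> outputs[0].shape  # updated hidden states
+        torch.Size([1, 5, 64])
+        >>> len(outputs[-1])  # self-attention (key, value) cache
+        2
+        ```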
+ """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): + config_class = TimeSeriesTransformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`TimeSeriesTransformerConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
+            Past values of the time series, that serve as context in order to predict the future. The sequence size of
+            this tensor must be larger than the `context_length` of the model, since the model will use the larger size
+            to construct lag features, i.e. additional values from the past which are added in order to serve as "extra
+            context".
+
+            The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no
+            `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest
+            look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of
+            the past.
+
+            The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
+            `static_categorical_features`, `static_real_features`, `past_time_features` and lags).
+
+            Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
+
+            For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+            variates in the time series per time step.
+        past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
+            Required time features, which the model internally will add to `past_values`. These could be things like
+            "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
+            could also be so-called "age" features, which basically help the model know "at which point in life" a
+            time-series is. Age features have small values for distant past time steps and increase monotonically the
+            more we approach the current time step. Holiday features are also a good example of time features.
+
+            These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+            the position encodings are learned from scratch internally as parameters of the model, the Time Series
+            Transformer requires you to provide additional time features. The Time Series Transformer only learns
+            additional embeddings for `static_categorical_features`.
+
+            Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+            must be known at prediction time.
+
+            The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in
+            `[0, 1]`:
+
+            - 1 for values that are **observed**,
+            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
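+
+            For instance, a length-5 univariate context whose third step is missing would zero-fill that position in
+            `past_values` and pass `past_observed_mask = torch.tensor([[1.0, 1.0, 0.0, 1.0, 1.0]])`.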
+
+        static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*):
+            Optional static categorical features for which the model will learn an embedding, which it will add to the
+            values of the time series.
+
+            Static categorical features are features which have the same value for all time steps (static over time).
+
+            A typical example of a static categorical feature is a time series ID.
+        static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
+            Optional static real features which the model will add to the values of the time series.
+
+            Static real features are features which have the same value for all time steps (static over time).
+
+            A typical example of a static real feature is promotion information.
+        future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*):
+            Future values of the time series, that serve as labels for the model. The `future_values` is what the
+            Transformer needs during training to learn to output, given the `past_values`.
+
+            The sequence length here is equal to `prediction_length`.
+
+            See the demo notebook and code snippets for details.
+
+            Optionally, during training any missing values need to be replaced with zeros and indicated via the
+            `future_observed_mask`.
+
+            For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+            variates in the time series per time step.
+        future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
+            Required time features for the prediction window, which the model internally will add to `future_values`.
+            These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as
+            Fourier features). These could also be so-called "age" features, which basically help the model know "at
+            which point in life" a time-series is. Age features have small values for distant past time steps and
+            increase monotonically the more we approach the current time step. Holiday features are also a good example
+            of time features.
+
+            These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+            the position encodings are learned from scratch internally as parameters of the model, the Time Series
+            Transformer requires you to provide additional time features. The Time Series Transformer only learns
+            additional embeddings for `static_categorical_features`.
+
+            Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+            must be known at prediction time.
+
+            The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+        future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+            Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
+            in `[0, 1]`:
+
+            - 1 for values that are **observed**,
+            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+
+            This mask is used to filter out missing values for the final loss calculation.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on certain token indices.
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Mask to avoid performing attention on certain token indices. By default, a causal mask will be used, to + make sure the model can only look at previous inputs in order to predict the future. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`TimeSeriesTransformerEncoderLayer`]. + + Args: + config: TimeSeriesTransformerConfig + """ + + def __init__(self, config: TimeSeriesTransformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = TimeSeriesValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = TimeSeriesSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([TimeSeriesTransformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
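+
+        Example (a shape-only sketch with a randomly initialized, internal encoder; `inputs_embeds` here are the
+        already-assembled lagged-value-and-covariate features of width `config.feature_size`):
+
+        ```python
+        >>> import torch
+        >>> from transformers import TimeSeriesTransformerConfig
+
+        >>> config = TimeSeriesTransformerConfig(prediction_length=12)
+        >>> encoder = TimeSeriesTransformerEncoder(config)
+        >>> features = torch.randn(1, config.context_length, config.feature_size)
+        >>> encoder(inputs_embeds=features).last_hidden_state.shape
+        torch.Size([1, 12, 64])
+        ```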
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a + [`TimeSeriesTransformerDecoderLayer`] + + Args: + config: TimeSeriesTransformerConfig + """ + + def __init__(self, config: TimeSeriesTransformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = TimeSeriesValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = TimeSeriesSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([TimeSeriesTransformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = inputs_embeds.size()[:-1] + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare Time Series Transformer Model outputting raw hidden-states without any specific head on top.",
+    TIME_SERIES_TRANSFORMER_START_DOCSTRING,
+)
+class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
+    def __init__(self, config: TimeSeriesTransformerConfig):
+        super().__init__(config)
+
+        if config.scaling == "mean" or config.scaling is True:
+            self.scaler = TimeSeriesMeanScaler(config)
+        elif config.scaling == "std":
+            self.scaler = TimeSeriesStdScaler(config)
+        else:
+            self.scaler = TimeSeriesNOPScaler(config)
+
+        if config.num_static_categorical_features > 0:
+            self.embedder = TimeSeriesFeatureEmbedder(
+                cardinalities=config.cardinality,
embedding_dims=config.embedding_dimension,
+            )
+
+        # transformer encoder-decoder and mask initializer
+        self.encoder = TimeSeriesTransformerEncoder(config)
+        self.decoder = TimeSeriesTransformerDecoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @property
+    def _past_length(self) -> int:
+        return self.config.context_length + max(self.config.lags_sequence)
+
+    def get_lagged_subsequences(
+        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
+    ) -> torch.Tensor:
+        """
+        Returns lagged subsequences of a given sequence as a tensor of shape (N, S, C, I), where S =
+        subsequences_length and I = len(indices), with indices = [lag - shift for lag in self.config.lags_sequence].
+        Specifically, lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
+
+        Args:
+            sequence: Tensor
+                The sequence from which lagged subsequences should be extracted. Shape: (N, T, C).
+            subsequences_length : int
+                Length of the subsequences to be extracted.
+            shift: int
+                Shift the lags back by this amount.
+        """
+        sequence_length = sequence.shape[1]
+        indices = [lag - shift for lag in self.config.lags_sequence]
+
+        if max(indices) + subsequences_length > sequence_length:
+            raise ValueError(
+                f"lags cannot go further than history length, found lag {max(indices)} "
+                f"while history length is only {sequence_length}"
+            )
+
+        lagged_values = []
+        for lag_index in indices:
+            begin_index = -lag_index - subsequences_length
+            end_index = -lag_index if lag_index > 0 else None
+            lagged_values.append(sequence[:, begin_index:end_index, ...])
+        return torch.stack(lagged_values, dim=-1)
+
+    def create_network_inputs(
+        self,
+        past_values: torch.Tensor,
+        past_time_features: torch.Tensor,
+        static_categorical_features: Optional[torch.Tensor] = None,
+        static_real_features: Optional[torch.Tensor] = None,
+        past_observed_mask: Optional[torch.Tensor] = None,
+        future_values: Optional[torch.Tensor] = None,
+        future_time_features: Optional[torch.Tensor] = None,
+    ):
+        # time feature
+        time_feat = (
+            torch.cat(
+                (
+                    past_time_features[:, self._past_length - self.config.context_length :, ...],
+                    future_time_features,
+                ),
+                dim=1,
+            )
+            if future_values is not None
+            else past_time_features[:, self._past_length - self.config.context_length :, ...]
+
+        )
+
+        # target
+        if past_observed_mask is None:
+            past_observed_mask = torch.ones_like(past_values)
+
+        context = past_values[:, -self.config.context_length :]
+        observed_context = past_observed_mask[:, -self.config.context_length :]
+        _, loc, scale = self.scaler(context, observed_context)
+
+        inputs = (
+            (torch.cat((past_values, future_values), dim=1) - loc) / scale
+            if future_values is not None
+            else (past_values - loc) / scale
+        )
+
+        # static features
+        log_abs_loc = loc.abs().log1p() if self.config.input_size == 1 else loc.squeeze(1).abs().log1p()
+        log_scale = scale.log() if self.config.input_size == 1 else scale.squeeze(1).log()
+        static_feat = torch.cat((log_abs_loc, log_scale), dim=1)
+
+        if static_real_features is not None:
+            static_feat = torch.cat((static_real_features, static_feat), dim=1)
+        if static_categorical_features is not None:
+            embedded_cat = self.embedder(static_categorical_features)
+            static_feat = torch.cat((embedded_cat, static_feat), dim=1)
+        expanded_static_feat = static_feat.unsqueeze(1).expand(-1, time_feat.shape[1], -1)
+
+        # all features
+        features = torch.cat((expanded_static_feat, time_feat), dim=-1)
+
+        # lagged features
+        subsequences_length = (
+            self.config.context_length + self.config.prediction_length
+            if future_values is not None
+            else self.config.context_length
+        )
+        lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length)
+        lags_shape = lagged_sequence.shape
+        reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
+
+        if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]:
+            raise ValueError(
+                f"input length {reshaped_lagged_sequence.shape[1]} and time feature length {time_feat.shape[1]} do not match"
+            )
+
+        # transformer inputs
+        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
+
+        return transformer_inputs, loc, scale, static_feat
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqTSModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_time_features: torch.Tensor,
+        past_observed_mask: torch.Tensor,
+        static_categorical_features: Optional[torch.Tensor] = None,
+        static_real_features: Optional[torch.Tensor] = None,
+        future_values: Optional[torch.Tensor] = None,
+        future_time_features: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Seq2SeqTSModelOutput, Tuple]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from huggingface_hub import hf_hub_download
+        >>> import torch
+        >>> from transformers import TimeSeriesTransformerModel
+
+        >>> file = hf_hub_download(
+        ...     repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
+        ... 
) + >>> batch = torch.load(file) + + >>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_inputs, loc, scale, static_feat = self.create_network_inputs( + past_values=past_values, + past_time_features=past_time_features, + past_observed_mask=past_observed_mask, + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + future_values=future_values, + future_time_features=future_time_features, + ) + + if encoder_outputs is None: + enc_input = transformer_inputs[:, : self.config.context_length, ...] + encoder_outputs = self.encoder( + inputs_embeds=enc_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + dec_input = transformer_inputs[:, self.config.context_length :, ...] 
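+        # `transformer_inputs` stacks `context_length` encoder steps followed by
+        # `prediction_length` decoder steps, so this slice is empty at inference time
+        # (when no `future_values` are given) and decoding is instead driven
+        # autoregressively by the prediction model's `generate` loop below.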
+ decoder_outputs = self.decoder( + inputs_embeds=dec_input, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + (loc, scale, static_feat) + + return Seq2SeqTSModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + loc=loc, + scale=scale, + static_features=static_feat, + ) + + +@add_start_docstrings( + "The Time Series Transformer Model with a distribution head on top for time-series forecasting.", + TIME_SERIES_TRANSFORMER_START_DOCSTRING, +) +class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel): + def __init__(self, config: TimeSeriesTransformerConfig): + super().__init__(config) + self.model = TimeSeriesTransformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + def output_params(self, dec_output): + return self.parameter_projection(dec_output) + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @torch.jit.ignore + def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: + sliced_params = params + if trailing_n is not None: + sliced_params = [p[:, -trailing_n:] for p in params] + return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) + + @add_start_docstrings_to_model_forward(TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqTSModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + past_observed_mask: torch.Tensor, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + future_values: Optional[torch.Tensor] = None, + future_time_features: Optional[torch.Tensor] = None, + future_observed_mask: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = 
None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqTSModelOutput, Tuple]: + r""" + Returns: + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import TimeSeriesTransformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = TimeSeriesTransformerForPrediction.from_pretrained( + ... "huggingface/time-series-transformer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... 
) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if future_values is not None: + use_cache = False + + outputs = self.model( + past_values=past_values, + past_time_features=past_time_features, + past_observed_mask=past_observed_mask, + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + future_values=future_values, + future_time_features=future_time_features, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + use_cache=use_cache, + return_dict=return_dict, + ) + + prediction_loss = None + params = None + if future_values is not None: + params = self.output_params(outputs[0]) # outputs.last_hidden_state + # loc is 3rd last and scale is 2nd last output + distribution = self.output_distribution(params, loc=outputs[-3], scale=outputs[-2]) + + loss = self.loss(distribution, future_values) + + if future_observed_mask is None: + future_observed_mask = torch.ones_like(future_values) + + if len(self.target_shape) == 0: + loss_weights = future_observed_mask + else: + loss_weights, _ = future_observed_mask.min(dim=-1, keepdim=False) + + prediction_loss = weighted_average(loss, weights=loss_weights) + + if not return_dict: + outputs = ((params,) + outputs[1:]) if params is not None else outputs[1:] + return ((prediction_loss,) + outputs) if prediction_loss is not None else outputs + + return Seq2SeqTSPredictionOutput( + loss=prediction_loss, + params=params, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + loc=outputs.loc, + scale=outputs.scale, + static_features=outputs.static_features, + ) + + @torch.no_grad() + def generate( + self, + past_values: torch.Tensor, + past_time_features: torch.Tensor, + future_time_features: torch.Tensor, + past_observed_mask: Optional[torch.Tensor] = None, + static_categorical_features: Optional[torch.Tensor] = None, + static_real_features: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SampleTSPredictionOutput: + r""" + Greedily generate sequences of sample predictions from a model with a probability distribution head. + + Parameters: + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size + of this tensor must be larger than the `context_length` of the model, since the model will use the + larger size to construct lag features, i.e. additional values from the past which are added in order to + serve as "extra context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if + no `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). 
The property `_past_length` returns the actual length
+                of the past.
+
+                The `past_values` is what the Transformer encoder gets as input (with optional additional features,
+                such as `static_categorical_features`, `static_real_features`, `past_time_features` and lags).
+
+                Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
+
+                For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number
+                of variates in the time series per time step.
+            past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
+                Required time features, which the model internally will add to `past_values`. These could be things
+                like "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features).
+                These could also be so-called "age" features, which basically help the model know "at which point in
+                life" a time-series is. Age features have small values for distant past time steps and increase
+                monotonically the more we approach the current time step. Holiday features are also a good example of
+                time features.
+
+                These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT,
+                where the position encodings are learned from scratch internally as parameters of the model, the Time
+                Series Transformer requires you to provide additional time features. The Time Series Transformer only
+                learns additional embeddings for `static_categorical_features`.
+
+                Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these
+                features must be known at prediction time.
+
+                The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+            future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
+                Required time features for the prediction window, which the model internally will add to sampled
+                predictions. These could be things like "month of year", "day of the month", etc. encoded as vectors
+                (for instance as Fourier features). These could also be so-called "age" features, which basically help
+                the model know "at which point in life" a time-series is. Age features have small values for distant
+                past time steps and increase monotonically the more we approach the current time step. Holiday features
+                are also a good example of time features.
+
+                These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT,
+                where the position encodings are learned from scratch internally as parameters of the model, the Time
+                Series Transformer requires you to provide additional time features. The Time Series Transformer only
+                learns additional embeddings for `static_categorical_features`.
+
+                Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these
+                features must be known at prediction time.
+
+                The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
+                in `[0, 1]`:
+
+                - 1 for values that are **observed**,
+                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+ + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to + the values of the time series. + + Static categorical features are features which have the same value for all time steps (static over + time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + + Return: + [`SampleTSPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of + samples, prediction_length)` or `(batch_size, number of samples, prediction_length, input_size)` for + multivariate predictions. + """ + outputs = self( + static_categorical_features=static_categorical_features, + static_real_features=static_real_features, + past_time_features=past_time_features, + past_values=past_values, + past_observed_mask=past_observed_mask, + future_time_features=future_time_features, + future_values=None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + use_cache=True, + ) + + decoder = self.model.get_decoder() + enc_last_hidden = outputs.encoder_last_hidden_state + loc = outputs.loc + scale = outputs.scale + static_feat = outputs.static_features + + num_parallel_samples = self.config.num_parallel_samples + repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0) + repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0) + + repeated_past_values = ( + past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc + ) / repeated_scale + + expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1) + features = torch.cat((expanded_static_feat, future_time_features), dim=-1) + repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0) + + repeated_enc_last_hidden = enc_last_hidden.repeat_interleave(repeats=num_parallel_samples, dim=0) + + future_samples = [] + + # greedy decoding + for k in range(self.config.prediction_length): + lagged_sequence = self.model.get_lagged_subsequences( + sequence=repeated_past_values, + subsequences_length=1 + k, + shift=1, + ) + + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) + + decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1) + + dec_output = decoder(inputs_embeds=decoder_input, encoder_hidden_states=repeated_enc_last_hidden) + dec_last_hidden = dec_output.last_hidden_state + + params = self.parameter_projection(dec_last_hidden[:, -1:]) + distr = self.output_distribution(params, loc=repeated_loc, scale=repeated_scale) + next_sample = distr.sample() + + repeated_past_values = torch.cat( + (repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1 + ) + 
future_samples.append(next_sample)
+
+        concat_future_samples = torch.cat(future_samples, dim=1)
+
+        return SampleTSPredictionOutput(
+            sequences=concat_future_samples.reshape(
+                (-1, num_parallel_samples, self.config.prediction_length) + self.target_shape,
+            )
+        )
+
+
+__all__ = ["TimeSeriesTransformerForPrediction", "TimeSeriesTransformerModel", "TimeSeriesTransformerPreTrainedModel"]
diff --git a/docs/transformers/src/transformers/models/timesfm/__init__.py b/docs/transformers/src/transformers/models/timesfm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f1541b9c549256b67b3a304cc482091842e94e
--- /dev/null
+++ b/docs/transformers/src/transformers/models/timesfm/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_timesfm import *
+    from .modeling_timesfm import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/timesfm/configuration_timesfm.py b/docs/transformers/src/transformers/models/timesfm/configuration_timesfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd371cf1b21879ada9d8923a6d01372476d3efc9
--- /dev/null
+++ b/docs/transformers/src/transformers/models/timesfm/configuration_timesfm.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TimesFM model configuration"""
+
+from typing import List
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimesFmConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`]. It is used to
+    instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the TimesFM
+    [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + patch_length (`int`, *optional*, defaults to 32): + The length of one patch in the input sequence. + context_length (`int`, *optional*, defaults to 512): + The length of the input context. + horizon_length (`int`, *optional*, defaults to 128): + The length of the prediction horizon. + freq_size (`int`, *optional*, defaults to 3): + The number of frequency embeddings. + num_hidden_layers (`int`, *optional*, defaults to 50): + Number of Transformer layers. + hidden_size (`int`, *optional*, defaults to 1280): + Size of the hidden layers in the feed-forward networks. + intermediate_size (`int`, *optional*, defaults to 1280): + Dimension of the MLP representations. + head_dim (`int`, *optional*, defaults to 80): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_attention_heads * head_dim`. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + tolerance (`float`, *optional*, defaults to 1e-06): + The tolerance for the quantile loss. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the RMS normalization layers. + quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`): + The quantiles to predict. + pad_val (`float`, *optional*, defaults to 1123581321.0): + The value used to pad the predictions. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention scores. + use_positional_embedding (`bool`, *optional*, defaults to `False`): + Whether to add positional embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + min_timescale (`int`, *optional*, defaults to 1): + The start of the geometric positional index. Determines the periodicity of + the added signal. + max_timescale (`int`, *optional*, defaults to 10000): + The end of the geometric positional index. Determines the frequency of the + added signal. 
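+
+    Example (a minimal illustrative sketch of the usual config workflow):
+
+    ```python
+    >>> from transformers import TimesFmConfig, TimesFmModelForPrediction

+    >>> # Initializing a default TimesFM configuration
+    >>> configuration = TimesFmConfig()

+    >>> # Initializing a model (with random weights) from that configuration
+    >>> model = TimesFmModelForPrediction(configuration)

+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```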
+ """ + + model_type = "timesfm" + keys_to_ignore_at_inference = [] + is_encoder_decoder = False + + def __init__( + self, + patch_length: int = 32, + context_length: int = 512, + horizon_length: int = 128, + freq_size: int = 3, + num_hidden_layers: int = 50, + hidden_size: int = 1280, + intermediate_size: int = 1280, + head_dim: int = 80, + num_attention_heads: int = 16, + tolerance: float = 1e-6, + rms_norm_eps: float = 1e-6, + quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + pad_val: float = 1123581321.0, + attention_dropout: float = 0.0, + use_positional_embedding: bool = False, + initializer_range: float = 0.02, + min_timescale: int = 1, + max_timescale: int = 10_000, + **kwargs, + ): + self.patch_length = patch_length + self.context_length = context_length + self.horizon_length = horizon_length + self.quantiles = quantiles + self.pad_val = pad_val + self.freq_size = freq_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.tolerance = tolerance + self.rms_norm_eps = rms_norm_eps + self.attention_dropout = attention_dropout + self.use_positional_embedding = use_positional_embedding + self.initializer_range = initializer_range + self.min_timescale = min_timescale + self.max_timescale = max_timescale + + super().__init__( + is_encoder_decoder=self.is_encoder_decoder, + **kwargs, + ) + + +__all__ = ["TimesFmConfig"] diff --git a/docs/transformers/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/docs/transformers/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..25ed42fc2064bbfaf4b652f807296d8f8d1801e9 --- /dev/null +++ b/docs/transformers/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py @@ -0,0 +1,277 @@ +import argparse +import os +import re +import shutil + +import numpy as np +import timesfm +import torch + +from transformers import TimesFmConfig, TimesFmModelForPrediction + + +""" +Sample usage: + +``` +python src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py \ + --output_dir /output/path +``` +""" + + +def get_nested_attr(obj, key): + """Recursively retrieves an attribute from an object, handling list/tuple indexing if present.""" + parts = key.split(".") + for part in parts: + match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]` + if match: + attr_name, index = match.groups() + obj = getattr(obj, attr_name)[int(index)] # Access list/tuple element + else: + obj = getattr(obj, part) # Regular attribute access + return obj + + +def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cuda" if torch.cuda.is_available() else "cpu", + per_core_batch_size=32, + horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + model_dims=1280, + use_positional_embedding=False, + context_len=2048, + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), + ) + + timesfm_config = TimesFmConfig( + patch_length=tfm.hparams.input_patch_len, + context_length=tfm.hparams.context_len, + horizon_length=tfm.hparams.horizon_len, + 
num_hidden_layers=tfm.hparams.num_layers,
+        hidden_size=tfm.hparams.model_dims,
+        intermediate_size=tfm.hparams.model_dims,
+        head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads,
+        num_attention_heads=tfm.hparams.num_heads,
+        use_positional_embedding=tfm.hparams.use_positional_embedding,
+    )
+    timesfm_config.save_pretrained(tmp_model_path)
+    timesfm_model = TimesFmModelForPrediction(timesfm_config)

+    # copy the weights from the original model to the new model
+    original_model = tfm._model

+    # mapping of the layers from the original model to the transformer model
+    MODEL_LAYER_MAPPING = {
+        "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.input_layer.weight",
+        "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.input_layer.bias",
+        "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight",
+        "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias",
+        "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight",
+        "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias",
+        "freq_emb.weight": "decoder.freq_emb.weight",
+        "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.input_layer.weight",
+        "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.input_layer.bias",
+        "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight",
+        "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias",
+        "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight",
+        "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias",
+    }

+    TRANSFORMER_LAYER_MAPPING = {
+        "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.layers[{i}].self_attn.qkv_proj.weight",
+        "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.layers[{i}].self_attn.qkv_proj.bias",
+        "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.layers[{i}].self_attn.o_proj.weight",
+        "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.layers[{i}].self_attn.o_proj.bias",
+        "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.layers[{i}].self_attn.scaling",
+        "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.layers[{i}].mlp.gate_proj.weight",
+        "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.layers[{i}].mlp.gate_proj.bias",
+        "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.layers[{i}].mlp.down_proj.weight",
+        "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.layers[{i}].mlp.down_proj.bias",
+        "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.layers[{i}].mlp.layer_norm.weight",
+        "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.layers[{i}].mlp.layer_norm.bias",
+        "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.layers[{i}].input_layernorm.weight",
+    }

+    for old_key, new_key in MODEL_LAYER_MAPPING.items():
+        try:
+            old_attr = get_nested_attr(original_model, old_key)  # Get tensor from original model
+            new_attr = get_nested_attr(timesfm_model, new_key)  # Get corresponding attribute in new model
+            new_attr.data.copy_(old_attr.data)  # Copy data
+        except AttributeError:
+            print(f"Skipping {old_key} (not found in original model).")

+    num_layers = len(timesfm_model.decoder.layers)
+    for i in range(num_layers):
+        for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items():
+            old_key = old_template.format(i=i)
+            new_key = 
new_template.format(i=i) + + try: + # Get tensor from original model + old_attr = get_nested_attr(original_model, old_key) + if "qkv_proj" in old_key: + # Split the tensor into q, k, v projections + q_proj, k_proj, v_proj = ( + old_attr[: tfm.hparams.model_dims, ...], + old_attr[tfm.hparams.model_dims : tfm.hparams.model_dims * 2, ...], + old_attr[tfm.hparams.model_dims * 2 :, ...], + ) + # Get corresponding attribute in new model + q_key = new_key.replace("qkv_proj", "q_proj") + q_attr = get_nested_attr(timesfm_model, q_key) + q_attr.data.copy_(q_proj.data) # Copy data + k_key = new_key.replace("qkv_proj", "k_proj") + k_attr = get_nested_attr(timesfm_model, k_key) + k_attr.data.copy_(k_proj.data) # Copy data + v_key = new_key.replace("qkv_proj", "v_proj") + v_attr = get_nested_attr(timesfm_model, v_key) + v_attr.data.copy_(v_proj.data) # Copy data + else: + # Get corresponding attribute in new model + new_attr = get_nested_attr(timesfm_model, new_key) + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") + + timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def check_outputs(model_path, huggingface_repo_id): + """Compares outputs between original and converted models.""" + print("\nChecking model outputs...") + + # Load original model + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cuda" if torch.cuda.is_available() else "cpu", + per_core_batch_size=32, + horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + context_len=2048, + model_dims=1280, + use_positional_embedding=False, + point_forecast_mode="mean", + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), + ) + + # Load converted model + converted_model = TimesFmModelForPrediction.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa", + ).to("cuda" if torch.cuda.is_available() else "cpu") + converted_model.eval() # Set to evaluation mode + + # Create test inputs + forecast_input = [ + np.sin(np.linspace(0, 20, 100)), + np.sin(np.linspace(0, 20, 200)), + np.sin(np.linspace(0, 20, 400)), + ] + frequency_input = [0, 1, 2] + + # Get predictions from original model + point_forecast_orig, quantile_forecast_orig = tfm.forecast( + forecast_input, + freq=frequency_input, + ) + + # Convert inputs to sequence of tensors + forecast_input_tensor = [ + torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu") + for ts in forecast_input + ] + frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + # Get predictions from converted model + with torch.no_grad(): + outputs = converted_model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) + point_forecast_conv = outputs.mean_predictions.float().cpu().numpy() + quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy() + + # Compare outputs + point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv) + quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv) + + max_point_diff = point_forecast_diff.max() + mean_point_diff = point_forecast_diff.mean() + max_quantile_diff = quantile_forecast_diff.max() + mean_quantile_diff = quantile_forecast_diff.mean() + + print("\nOutput comparison:") + print(f"Point forecast - Max difference: {max_point_diff:.6f}") 
+ print(f"Point forecast - Mean difference: {mean_point_diff:.6f}") + print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}") + print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}") + + # Define acceptable thresholds + POINT_THRESHOLD = 1e-5 + QUANTILE_THRESHOLD = 1e-5 + + if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD: + raise ValueError( + f"Output mismatch detected!\n" + f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n" + f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})" + ) + + print("\n✓ All outputs match within acceptable tolerance!") + + # Optional: Print shapes for verification + print("\nOutput shapes:") + print(f"Original point forecast: {point_forecast_orig.shape}") + print(f"Converted point forecast: {point_forecast_conv.shape}") + print(f"Original quantile forecast: {quantile_forecast_orig.shape}") + print(f"Converted quantile forecast: {quantile_forecast_conv.shape}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default="google/timesfm-2.0-500m-pytorch", + help="The Hugging Face repository ID to use for the model.", + ) + args = parser.parse_args() + + # if the saved model file exists, skip the conversion + if os.path.exists( + os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin") + ): + print(f"Model already exists in {args.output_dir}, skipping conversion.") + else: + write_model( + model_path=args.output_dir, + safe_serialization=args.safe_serialization, + huggingface_repo_id=args.huggingface_repo_id, + ) + check_outputs(args.output_dir, args.huggingface_repo_id) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/src/transformers/models/timesfm/modeling_timesfm.py b/docs/transformers/src/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 0000000000000000000000000000000000000000..e34f823500da1af5eb049fe593fd4f19cc44d6cc --- /dev/null +++ b/docs/transformers/src/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,908 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_timesfm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
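+
+# Usage sketch (illustrative only; the prediction head's forward accepts a sequence of
+# 1-d context tensors plus integer frequency codes, as exercised by the conversion
+# script's check_outputs):
+#
+#     import torch
+#     from transformers import TimesFmModelForPrediction
+#
+#     model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
+#     past_values = [torch.sin(torch.linspace(0, 20, 512))]  # one context series
+#     freq = torch.tensor([0])  # 0 = high, 1 = medium, 2 = low frequency
+#     outputs = model(past_values=past_values, freq=freq, return_dict=True)
+#     point_forecast = outputs.mean_predictions  # quantiles in outputs.full_predictions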
+ +import math +from dataclasses import dataclass +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...integrations import use_kernel_forward_from_hub +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from .configuration_timesfm import TimesFmConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "TimesFmConfig" + + +@dataclass +class TimesFmOutput(BaseModelOutput): + """ + Args: + loc (`torch.Tensor` of shape `(batch_size, )`): + The mean of the time series inputs. + scale (`torch.Tensor` of shape `(batch_size,)`): + The scale of the time series inputs. + """ + + loc: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + + +@dataclass +class TimesFmOutputForPrediction(BaseModelOutput): + """ + Args: + mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The mean predictions of the time series. + full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The full predictions of the time series including the mean and the quantiles. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + The loss of the TimesFM model. + """ + + mean_predictions: Optional[torch.Tensor] = None + full_predictions: Optional[torch.Tensor] = None + loss: Optional[Union[torch.Tensor, float]] = None + + +class TimesFmMLP(nn.Module): + """Pax MLP in pytorch.""" + + def __init__(self, config: TimesFmConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(hidden_size, intermediate_size) + self.down_proj = nn.Linear(intermediate_size, hidden_size) + self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def forward(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFmResidualBlock(nn.Module): + """TimesFM residual block.""" + + def __init__(self, input_dims, hidden_dims, output_dims): + super().__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + self.input_layer = nn.Linear(input_dims, hidden_dims) + self.activation = nn.SiLU() + self.output_layer = nn.Linear(hidden_dims, output_dims) + self.residual_layer = nn.Linear(input_dims, output_dims) + + def forward(self, x): + hidden = self.input_layer(x) + hidden = self.activation(hidden) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +@use_kernel_forward_from_hub("RMSNorm") +class TimesFmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + TimesFmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * 
torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)

+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+class TimesFmPositionalEmbedding(nn.Module):
+    """Generates position embedding for a given 1-d sequence."""

+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        min_timescale = config.min_timescale
+        max_timescale = config.max_timescale
+        self.embedding_dims = config.hidden_size

+        num_timescales = self.embedding_dims // 2
+        log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+        self.register_buffer(
+            "inv_timescales",
+            min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+        )

+    def forward(self, seq_length=None, position=None):
+        """Generates a Tensor of sinusoids with different frequencies.

+        Args:
+            seq_length: an optional Python int defining the output sequence length;
+                not required if the `position` argument is specified.
+            position: [B, seq_length], optional position for each token in the
+                sequence, only required when the sequence is packed.

+        Returns:
+            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+        """
+        if position is None and seq_length is None:
+            raise ValueError("Either position or seq_length must be provided")

+        if position is None:
+            # [1, seqlen]
+            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+        elif position.ndim != 2:
+            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")

+        scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)

+        # Padding to ensure correct embedding dimension
+        signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+        return signal


+def simple_eager_attention_forward(
+    module: nn.Module,
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask

+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()

+    return attn_output, attn_weights


+class TimesFmAttention(nn.Module):
+    """Implements the attention used in TimesFM. 
One key difference is that there is _per_dim_scaling of the query.""" + + def __init__(self, config: TimesFmConfig, layer_idx: int): + super().__init__() + self.config = config + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.layer_idx = layer_idx + + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_heads * self.head_dim + self.scaling = nn.Parameter(torch.empty((self.head_dim,))) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _scale_query(self, query: torch.Tensor) -> torch.Tensor: + scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim)) + return query * scale[None, None, None, :] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self._scale_query(query_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class TimesFmDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__(self, config: TimesFmConfig, layer_idx: int): + super().__init__() + + self.self_attn = TimesFmAttention(config, layer_idx=layer_idx) + self.mlp = TimesFmMLP(config) + self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + paddings: torch.Tensor, + output_attentions: bool = False, + ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, scores = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +TIMESFM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TimesFmConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFmPreTrainedModel(PreTrainedModel): + """handles the loading for all models.""" + + config_class = TimesFmConfig + base_model_prefix = "timesfm" + _no_split_modules = ["TimesFmDecoderLayer"] + main_input_name = "past_values" + _supports_sdpa = True + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0, std=self.config.initializer_range) + + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + elif isinstance(module, TimesFmRMSNorm): + nn.init.zeros_(module.weight) + + elif isinstance(module, TimesFmAttention): + # Initialize scaling parameter + nn.init.ones_(module.scaling) + + +TIMESFM_INPUTS_DOCSTRING = r""" + Args: + past_values: list of time series forecast contexts. Each context time series + can be a torch Tensor of potentially different context lengths. + freq: frequency of each context time series in the inputs. 
0 for high frequency
+            (default), 1 for medium, and 2 for low.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+"""


+@add_start_docstrings(
+    "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+    TIMESFM_START_DOCSTRING,
+)
+class TimesFmModel(TimesFmPreTrainedModel):
+    """Patched time-series decoder without any specific output layer."""

+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)

+        self.config = config
+        self.input_ff_layer = TimesFmResidualBlock(
+            input_dims=2 * config.patch_length,
+            output_dims=config.hidden_size,
+            hidden_dims=config.intermediate_size,
+        )
+        self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
+        self.layers = nn.ModuleList(
+            [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        if self.config.use_positional_embedding:
+            self.position_emb = TimesFmPositionalEmbedding(config=config)

+        # Initialize weights and apply final processing
+        self.post_init()

+    def _forward_transform(
+        self, inputs: torch.Tensor, patched_pads: torch.Tensor
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        """Input is of shape [B, N, P]."""
+        mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
+        sigma = torch.where(
+            sigma < self.config.tolerance,
+            torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+            sigma,
+        )

+        # Normalize each patch
+        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+        outputs = torch.where(
+            torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+            torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
+            outputs,
+        )
+        return outputs, (mu, sigma)

+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_values_padding: torch.LongTensor,
+        freq: torch.Tensor,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+    ) -> TimesFmOutput:
+        """
+        past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            The padding indicator of the time series.
+        """
+        # Reshape into patches (using view for efficiency)
+        bsize = past_values.shape[0]
+        patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
+        patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)

+        patched_inputs = torch.where(
+            torch.abs(patched_pads - 1.0) < self.config.tolerance,
+            torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
+            patched_inputs,
+        )
+        patched_pads = torch.where(
+            torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+            torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+            patched_pads,
+        )
+        patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)

+        # B x N x D
+        patched_inputs = patched_inputs * (1.0 - patched_pads)
+        concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+        model_input = self.input_ff_layer(concat_inputs)

+        # A patch should not be padded even if there is at least one zero.
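+        # e.g. per-position pads [1., 0., 1.] give min 0. (patch treated as observed),
+        # while [1., 1., 1.] give min 1. (the whole patch stays masked).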
+        patched_padding = torch.min(patched_pads, dim=-1)[0]  # torch.min returns (values, indices); keep the values
+        if self.config.use_positional_embedding:
+            pos_emb = self.position_emb(model_input.shape[1])
+            pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+            pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
+            model_input += pos_emb
+
+        f_emb = self.freq_emb(freq)  # B x 1 x D
+        model_input += f_emb
+
+        # Convert paddings to attention mask and combine with causal mask
+        hidden_states = model_input
+        attention_mask = self._prepare_4d_attention_mask(
+            attention_mask=patched_padding,
+            sequence_length=hidden_states.shape[1],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+            is_causal=True,
+        )
+
+        all_attentions = []
+        all_hidden_states = []
+
+        for layer in self.layers[: self.config.num_hidden_layers]:
+            scores, hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                paddings=patched_padding,
+                output_attentions=output_attentions,
+            )
+            if output_attentions:
+                all_attentions.append(scores)
+            if output_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = [model_input] + all_hidden_states
+        else:
+            all_hidden_states = None
+
+        return TimesFmOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions if output_attentions else None,
+            loc=stats[0],
+            scale=stats[1],
+        )
+
+    @staticmethod
+    def _prepare_4d_attention_mask(
+        attention_mask: Optional[torch.Tensor],
+        sequence_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        is_causal: bool = True,
+    ) -> Optional[torch.Tensor]:
+        """
+        Creates 4D attention mask and combines causal and padding masks if needed.
+
+        Args:
+            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+            sequence_length: Length of the sequence
+            dtype: Data type of the mask
+            device: Device of the mask
+            is_causal: Whether to apply causal masking
+
+        Returns:
+            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
+        """
+        # Get minimum value for the dtype
+        min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+
+        # Handle padding mask
+        if attention_mask is not None:
+            # Convert 2D padding mask to 4D attention mask
+            attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+            attention_mask = attention_mask * min_value
+
+        # Create causal mask if needed
+        if is_causal:
+            causal_mask = torch.triu(
+                torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
+                diagonal=1,
+            )
+            causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+            # Combine with padding mask if it exists
+            if attention_mask is not None:
+                attention_mask = torch.minimum(attention_mask, causal_mask)
+            else:
+                attention_mask = causal_mask
+
+        return attention_mask
+
+    @staticmethod
+    def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculates mean and standard deviation of `inputs` across axis 1.
+
+        It excludes values where `padding` is 1.
+
+        Args:
+            inputs: A PyTorch tensor of shape [b, n, p].
+            padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+        Returns:
+            A tuple containing the mean and standard deviation.
+            We return the statistics of the first patch with at least three non-padded values.
+        """
+
+        # Selecting the first patch with at least 3 non-padded values.
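+        # Illustrative example: for pad_sum = [[2., 5., 4.], [1., 2., 0.]],
+        # (arr >= 3) is [[0, 1, 1], [0, 0, 0]], so argmax selects patch 1 for the
+        # first row, while the all-zero second row falls back to the last patch (index 2).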
+        def _get_patch_index(arr: torch.Tensor):
+            indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+            row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+            return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+        pad_sum = torch.sum(1 - padding, dim=2)
+        patch_indices = _get_patch_index(pad_sum)
+        bidxs = torch.arange(inputs.shape[0])
+
+        arr = inputs[bidxs, patch_indices, :]
+        pad = padding[bidxs, patch_indices, :]
+
+        # Create a mask where padding is 0
+        mask = 1 - pad
+
+        # Calculate the number of valid elements
+        num_valid_elements = torch.sum(mask, dim=1)
+        num_valid_elements = torch.where(
+            num_valid_elements == 0,
+            torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
+            num_valid_elements,
+        )
+
+        # Calculate the masked sum and squared sum
+        masked_sum = torch.sum(arr * mask, dim=1)
+        masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
+
+        # Calculate the masked mean and standard deviation
+        masked_mean = masked_sum / num_valid_elements
+        masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+        masked_var = torch.where(
+            masked_var < 0.0,
+            torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+            masked_var,
+        )
+        masked_std = torch.sqrt(masked_var)
+
+        return masked_mean, masked_std
+
+    @staticmethod
+    def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+        """Shifts rows of seq based on the first 0 in each row of the mask.
+
+        Args:
+            mask: mask tensor of shape [B, N]
+            seq: seq tensor of shape [B, N, P]
+
+        Returns:
+            The shifted sequence.
+        """
+        batch_size, num_seq, feature_dim = seq.shape
+
+        new_mask: torch.BoolTensor = mask == 0
+
+        # Use argmax to find the first True value in each row
+        indices = new_mask.to(torch.int32).argmax(dim=1)
+
+        # Handle rows with all zeros
+        indices[~new_mask.any(dim=1)] = -1
+
+        # Create index ranges for each sequence in the batch
+        idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
+
+        # Calculate shifted indices for each element in each sequence
+        shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+        # Gather values from seq using shifted indices
+        shifted_seq = seq.gather(1, shifted_idx)
+
+        return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+    """TimesFM model for quantile and mean prediction."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.context_len = config.context_length
+        self.horizon_len = config.horizon_length
+
+        self.decoder = TimesFmModel(config)
+
+        # quantile and mean output
+        self.horizon_ff_layer = TimesFmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * (1 + len(config.quantiles)),
+            hidden_dims=config.intermediate_size,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _preprocess(
+        self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Formats and pads raw inputs to feed into the model.
+
+        This function pads (or truncates) each time series to match the model
+        context length and stacks the results into batched tensors.
+
+        Args:
+            inputs: A list of 1d Tensors. Each Tensor is the context time series of
+                a single forecast task.
+            freq: list of frequencies
+
+        Returns:
+            A tuple of:
+            - the padded input time series to meet the model required context.
+            - the padding indicator.
+            - the frequency of each input time series, as an int32 tensor of shape `(batch_size, 1)`.
+        """
+        input_ts, input_padding, inp_freq = [], [], []
+
+        for i, ts in enumerate(inputs):
+            input_len = ts.shape[0]
+            padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
+            if input_len < self.context_len:
+                num_front_pad = self.context_len - input_len
+                ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
+                padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
+            elif input_len > self.context_len:
+                ts = ts[-self.context_len :]
+                padding = padding[-(self.context_len + self.horizon_len) :]
+
+            input_ts.append(ts)
+            input_padding.append(padding)
+            inp_freq.append(freq[i])
+
+        return (
+            torch.stack(input_ts, dim=0),
+            torch.stack(input_padding, dim=0),
+            torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
+        )
+
+    def _postprocess_output(
+        self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
+    ) -> torch.Tensor:
+        """Postprocess output of stacked transformer."""
+
+        # B x N x (H.Q)
+        output_ts = self.horizon_ff_layer(model_output)
+
+        # Reshape using view
+        b, n, _ = output_ts.shape
+        output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
+
+        mu, sigma = stats
+        return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
+
+    def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        losses = []
+        for i, q in enumerate(self.config.quantiles):
+            errors = targets - predictions[..., i]
+            loss = torch.max((q - 1) * errors, q * errors)
+            losses.append(loss.mean())
+        return torch.stack(losses).mean()
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        past_values: Sequence[torch.Tensor],
+        freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
+        window_size: Optional[int] = None,
+        future_values: Optional[torch.Tensor] = None,
+        forecast_context_len: Optional[int] = None,
+        return_forecast_on_context: bool = False,
+        truncate_negative: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TimesFmOutputForPrediction:
+        r"""
+        window_size (`int`, *optional*):
+            Window size of trend + residual decomposition. If None then we do not do decomposition.
+        future_values (`torch.Tensor`, *optional*):
+            Optional future time series values to be used for loss computation.
+        forecast_context_len (`int`, *optional*):
+            Optional max context length.
+        return_forecast_on_context (`bool`, *optional*):
+            True to return the forecast on the context when available, i.e. after the first input patch.
+        truncate_negative (`bool`, *optional*):
+            Truncate to only non-negative values if all of the contexts are non-negative,
+            otherwise do nothing.
+        output_attentions (`bool`, *optional*):
+            Whether to output the attentions.
+        output_hidden_states (`bool`, *optional*):
+            Whether to output the hidden states.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import TimesFmModelForPrediction
+
+        >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
+
+        >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
+        >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
+
+        >>> # Generate
+        >>> with torch.no_grad():
+        ...     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
+        ...     point_forecast_conv = outputs.mean_predictions
+        ...     quantile_forecast_conv = outputs.full_predictions
+        ```
+        """
+        if forecast_context_len is None:
+            fcontext_len = self.context_len
+        else:
+            fcontext_len = forecast_context_len
+
+        # Get device from first input tensor
+        device = past_values[0].device
+
+        # Truncate inputs to forecast_context_len
+        inputs = [ts[-fcontext_len:] for ts in past_values]
+        inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
+
+        if window_size is not None:
+            new_inputs = []
+            new_freqs = []
+            for i, ts in enumerate(inputs):
+                new_inputs.extend(self._timesfm_moving_average(ts, window_size))
+                if freq is not None:
+                    new_freqs.extend([freq[i]] * 2)
+            inputs = new_inputs
+            if freq is not None:
+                freq = new_freqs
+
+        if freq is None:
+            logger.info("No frequency provided via `freq`. Defaulting to high (0).")
+            freq = [0] * len(inputs)
+
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
+        # Move tensors to the same device as input
+        input_ts = input_ts.to(device)
+        input_padding = input_padding.to(device)
+        inp_freq = inp_freq.to(device)
+
+        final_out = input_ts
+        context_len = final_out.shape[1]
+        full_outputs = []
+
+        if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
+            raise ValueError(
+                "Length of paddings must match length of input + horizon_len:"
+                f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
+            )
+        output_patch_len = self.config.horizon_length
+
+        num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
+        for step_index in range(num_decode_patches):
+            current_padding = input_padding[:, 0 : final_out.shape[1]]
+            input_ts = final_out[:, -fcontext_len:]
+            input_padding = current_padding[:, -fcontext_len:]
+            decoder_output = self.decoder(
+                past_values=input_ts,
+                past_values_padding=input_padding,
+                freq=inp_freq,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            fprop_outputs = self._postprocess_output(
+                decoder_output.last_hidden_state,
+                (decoder_output.loc, decoder_output.scale),
+            )
+
+            if return_forecast_on_context and step_index == 0:
+                # For the first decoding step, collect the model forecast on the
+                # context, except for the first input patch whose forecast is unavailable.
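+                # fprop_outputs has shape [B, N, horizon_length, 1 + len(quantiles)];
+                # dropping the last patch and keeping the first `patch_length` horizon
+                # steps gives one next-patch forecast per context position.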
+                new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
+                # We have to use reshape and not view for non-contiguous memory
+                new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
+
+                full_outputs.append(new_full_ts)
+
+            # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+            # (full batch, last patch, output_patch_len, all output indices)
+            full_outputs.append(new_full_ts)
+            final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+        if return_forecast_on_context:
+            # `full_outputs` indexing starts right after the first input patch.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[
+                :, : (context_len - self.config.patch_length + self.horizon_len), :
+            ]
+        else:
+            # `full_outputs` indexing starts at the forecast horizon.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
+
+        mean_outputs = full_outputs[:, :, 0]
+        if window_size is not None:
+            mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+            full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+        if inp_min >= 0 and truncate_negative:
+            # `torch.maximum` requires a tensor operand, so clamp to zero instead.
+            mean_outputs = torch.clamp(mean_outputs, min=0.0)
+            full_outputs = torch.clamp(full_outputs, min=0.0)
+
+        loss = None
+        if future_values is not None:
+            mse_loss = F.mse_loss(mean_outputs, future_values)
+            quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
+            loss = mse_loss + quantile_loss
+
+        return TimesFmOutputForPrediction(
+            last_hidden_state=decoder_output.last_hidden_state,
+            attentions=decoder_output.attentions if output_attentions else None,
+            hidden_states=decoder_output.hidden_states if output_hidden_states else None,
+            mean_predictions=mean_outputs,
+            full_predictions=full_outputs,
+            loss=loss,
+        )
+
+    @staticmethod
+    def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
+        """Calculates the moving average using PyTorch's convolution function."""
+        # Pad with zeros to handle initial window positions
+        arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
+        # Create a convolution kernel
+        kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
+        # Apply convolution to calculate the moving average
+        smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
+        return [smoothed_arr, arr - smoothed_arr]
+
+
+__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
diff --git a/docs/transformers/src/transformers/models/timesfm/modular_timesfm.py b/docs/transformers/src/transformers/models/timesfm/modular_timesfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e285a4a19e2006373f84399c1805057876c6b5dc
--- /dev/null
+++ b/docs/transformers/src/transformers/models/timesfm/modular_timesfm.py
@@ -0,0 +1,866 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TimesFM model.""" + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from ..llama.modeling_llama import LlamaRMSNorm +from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward +from .configuration_timesfm import TimesFmConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch" +_CONFIG_FOR_DOC = "TimesFmConfig" + + +@dataclass +class TimesFmOutput(BaseModelOutput): + """ + Args: + loc (`torch.Tensor` of shape `(batch_size, )`): + The mean of the time series inputs. + scale (`torch.Tensor` of shape `(batch_size,)`): + The scale of the time series inputs. + """ + + loc: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + + +@dataclass +class TimesFmOutputForPrediction(BaseModelOutput): + """ + Args: + mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The mean predictions of the time series. + full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The full predictions of the time series including the mean and the quantiles. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + The loss of the TimesFM model. 
+    """
+
+    mean_predictions: Optional[torch.Tensor] = None
+    full_predictions: Optional[torch.Tensor] = None
+    loss: Optional[Union[torch.Tensor, float]] = None
+
+
+class TimesFmMLP(nn.Module):
+    """Pax MLP in PyTorch."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size)
+        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+    def forward(self, x, paddings=None):
+        gate_inp = self.layer_norm(x)
+        gate = self.gate_proj(gate_inp)
+        gate = F.relu(gate)
+        outputs = self.down_proj(gate)
+        if paddings is not None:
+            outputs = outputs * (1.0 - paddings[:, :, None])
+        return outputs + x
+
+
+class TimesFmResidualBlock(nn.Module):
+    """TimesFM residual block."""
+
+    def __init__(self, input_dims, hidden_dims, output_dims):
+        super().__init__()
+        self.input_dims = input_dims
+        self.hidden_dims = hidden_dims
+        self.output_dims = output_dims
+
+        self.input_layer = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
+        self.output_layer = nn.Linear(hidden_dims, output_dims)
+        self.residual_layer = nn.Linear(input_dims, output_dims)
+
+    def forward(self, x):
+        hidden = self.input_layer(x)
+        hidden = self.activation(hidden)
+        output = self.output_layer(hidden)
+        residual = self.residual_layer(x)
+        return output + residual
+
+
+class TimesFmRMSNorm(LlamaRMSNorm):
+    pass
+
+
+class TimesFmPositionalEmbedding(nn.Module):
+    """Generates position embedding for a given 1-d sequence."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        min_timescale = config.min_timescale
+        max_timescale = config.max_timescale
+        self.embedding_dims = config.hidden_size
+
+        num_timescales = self.embedding_dims // 2
+        log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+        self.register_buffer(
+            "inv_timescales",
+            min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+        )
+
+    def forward(self, seq_length=None, position=None):
+        """Generates a Tensor of sinusoids with different frequencies.
+
+        Args:
+            seq_length: an optional Python int defining the output sequence length.
+                Not needed if the `position` argument is specified.
+            position: [B, seq_length], optional position for each token in the
+                sequence, only required when the sequence is packed.
+
+        Returns:
+            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+        """
+        if position is None and seq_length is None:
+            raise ValueError("Either position or seq_length must be provided")
+
+        if position is None:
+            # [1, seqlen]
+            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+        elif position.ndim != 2:
+            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
+
+        scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+
+        # Padding to ensure correct embedding dimension
+        signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+        return signal
+
+
+class TimesFmAttention(nn.Module):
+    """Implements the attention used in TimesFM.
+    One key difference is that the query is scaled per head dimension with a learned parameter."""
+
+    def __init__(self, config: TimesFmConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+        self.layer_idx = layer_idx
+
+        self.num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_dim = config.head_dim
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_heads * self.head_dim
+        self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+    def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
+        # softplus keeps the learned per-dimension scale positive; 1.442695041 ≈ 1 / ln(2).
+        scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
+        return query * scale[None, None, None, :]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        query_states = self._scale_query(query_states)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = simple_eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=1.0,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class TimesFmDecoderLayer(nn.Module):
+    """Transformer layer."""
+
+    def __init__(self, config: TimesFmConfig, layer_idx: int):
+        super().__init__()
+
+        self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
+        self.mlp = TimesFmMLP(config)
+        self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        paddings: torch.Tensor,
+        output_attentions: bool = False,
+    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, scores = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        # MLP
+        hidden_states = self.mlp(hidden_states, paddings=paddings)
+
+        return scores, hidden_states
+
+
+TIMESFM_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
+    usage and behavior.
+
+    Parameters:
+        config ([`TimesFmConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+    TIMESFM_START_DOCSTRING,
+)
+class TimesFmPreTrainedModel(PreTrainedModel):
+    """Handles weight initialization and checkpoint loading for all TimesFM models."""
+
+    config_class = TimesFmConfig
+    base_model_prefix = "timesfm"
+    _no_split_modules = ["TimesFmDecoderLayer"]
+    main_input_name = "past_values"
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+
+        elif isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
+
+        elif isinstance(module, TimesFmRMSNorm):
+            nn.init.zeros_(module.weight)
+
+        elif isinstance(module, TimesFmAttention):
+            # Initialize scaling parameter
+            nn.init.ones_(module.scaling)
+
+
+TIMESFM_INPUTS_DOCSTRING = r"""
+    Args:
+        past_values: list of time series forecast contexts. Each context time series
+            can be a torch Tensor of potentially different context lengths.
+        freq: frequency of each context time series in the inputs.
+            0 for high frequency (default), 1 for medium, and 2 for low.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+"""
+
+
+@add_start_docstrings(
+    "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+    TIMESFM_START_DOCSTRING,
+)
+class TimesFmModel(TimesFmPreTrainedModel):
+    """Patched time-series decoder without any specific output layer."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.input_ff_layer = TimesFmResidualBlock(
+            input_dims=2 * config.patch_length,
+            output_dims=config.hidden_size,
+            hidden_dims=config.intermediate_size,
+        )
+        self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
+        self.layers = nn.ModuleList(
+            [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        if self.config.use_positional_embedding:
+            self.position_emb = TimesFmPositionalEmbedding(config=config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _forward_transform(
+        self, inputs: torch.Tensor, patched_pads: torch.Tensor
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        """Input is of shape [B, N, P]."""
+        mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
+        sigma = torch.where(
+            sigma < self.config.tolerance,
+            torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+            sigma,
+        )
+
+        # Normalize each patch
+        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+        outputs = torch.where(
+            torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+            torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
+            outputs,
+        )
+        return outputs, (mu, sigma)
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_values_padding: torch.LongTensor,
+        freq: torch.Tensor,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+    ) -> TimesFmOutput:
+        """
+        past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            The padding indicator of the time series.
+        """
+        # Reshape into patches (using view for efficiency)
+        bsize = past_values.shape[0]
+        patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
+        patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
+
+        patched_inputs = torch.where(
+            torch.abs(patched_pads - 1.0) < self.config.tolerance,
+            torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
+            patched_inputs,
+        )
+        patched_pads = torch.where(
+            torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+            torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+            patched_pads,
+        )
+        patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
+
+        # B x N x D
+        patched_inputs = patched_inputs * (1.0 - patched_pads)
+        concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+        model_input = self.input_ff_layer(concat_inputs)
+
+        # A patch is treated as observed (not padded) as long as it contains at least one non-padded value.
+        patched_padding = torch.min(patched_pads, dim=-1)[0]  # torch.min returns (values, indices); keep the values
+        if self.config.use_positional_embedding:
+            pos_emb = self.position_emb(model_input.shape[1])
+            pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+            pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
+            model_input += pos_emb
+
+        f_emb = self.freq_emb(freq)  # B x 1 x D
+        model_input += f_emb
+
+        # Convert paddings to attention mask and combine with causal mask
+        hidden_states = model_input
+        attention_mask = self._prepare_4d_attention_mask(
+            attention_mask=patched_padding,
+            sequence_length=hidden_states.shape[1],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+            is_causal=True,
+        )
+
+        all_attentions = []
+        all_hidden_states = []
+
+        for layer in self.layers[: self.config.num_hidden_layers]:
+            scores, hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                paddings=patched_padding,
+                output_attentions=output_attentions,
+            )
+            if output_attentions:
+                all_attentions.append(scores)
+            if output_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = [model_input] + all_hidden_states
+        else:
+            all_hidden_states = None
+
+        return TimesFmOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions if output_attentions else None,
+            loc=stats[0],
+            scale=stats[1],
+        )
+
+    @staticmethod
+    def _prepare_4d_attention_mask(
+        attention_mask: Optional[torch.Tensor],
+        sequence_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        is_causal: bool = True,
+    ) -> Optional[torch.Tensor]:
+        """
+        Creates 4D attention mask and combines causal and padding masks if needed.
+
+        Args:
+            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+            sequence_length: Length of the sequence
+            dtype: Data type of the mask
+            device: Device of the mask
+            is_causal: Whether to apply causal masking
+
+        Returns:
+            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
+        """
+        # Get minimum value for the dtype
+        min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+
+        # Handle padding mask
+        if attention_mask is not None:
+            # Convert 2D padding mask to 4D attention mask
+            attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+            attention_mask = attention_mask * min_value
+
+        # Create causal mask if needed
+        if is_causal:
+            causal_mask = torch.triu(
+                torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
+                diagonal=1,
+            )
+            causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+            # Combine with padding mask if it exists
+            if attention_mask is not None:
+                attention_mask = torch.minimum(attention_mask, causal_mask)
+            else:
+                attention_mask = causal_mask
+
+        return attention_mask
+
+    @staticmethod
+    def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculates mean and standard deviation of `inputs` across axis 1.
+
+        It excludes values where `padding` is 1.
+
+        Args:
+            inputs: A PyTorch tensor of shape [b, n, p].
+            padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+        Returns:
+            A tuple containing the mean and standard deviation.
+            We return the statistics of the first patch with at least three non-padded values.
+        """
+
+        # Selecting the first patch with at least 3 non-padded values.
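+        # Note: `argmax` on a row where no patch qualifies would return 0, so the
+        # row_sum fallback below redirects such rows to the last patch instead.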
+        def _get_patch_index(arr: torch.Tensor):
+            indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+            row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+            return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+        pad_sum = torch.sum(1 - padding, dim=2)
+        patch_indices = _get_patch_index(pad_sum)
+        bidxs = torch.arange(inputs.shape[0])
+
+        arr = inputs[bidxs, patch_indices, :]
+        pad = padding[bidxs, patch_indices, :]
+
+        # Create a mask where padding is 0
+        mask = 1 - pad
+
+        # Calculate the number of valid elements
+        num_valid_elements = torch.sum(mask, dim=1)
+        num_valid_elements = torch.where(
+            num_valid_elements == 0,
+            torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
+            num_valid_elements,
+        )
+
+        # Calculate the masked sum and squared sum
+        masked_sum = torch.sum(arr * mask, dim=1)
+        masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
+
+        # Calculate the masked mean and standard deviation
+        masked_mean = masked_sum / num_valid_elements
+        masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+        masked_var = torch.where(
+            masked_var < 0.0,
+            torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+            masked_var,
+        )
+        masked_std = torch.sqrt(masked_var)
+
+        return masked_mean, masked_std
+
+    @staticmethod
+    def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+        """Shifts rows of seq based on the first 0 in each row of the mask.
+
+        Args:
+            mask: mask tensor of shape [B, N]
+            seq: seq tensor of shape [B, N, P]
+
+        Returns:
+            The shifted sequence.
+        """
+        batch_size, num_seq, feature_dim = seq.shape
+
+        new_mask: torch.BoolTensor = mask == 0
+
+        # Use argmax to find the first True value in each row
+        indices = new_mask.to(torch.int32).argmax(dim=1)
+
+        # Handle rows with all zeros
+        indices[~new_mask.any(dim=1)] = -1
+
+        # Create index ranges for each sequence in the batch
+        idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
+
+        # Calculate shifted indices for each element in each sequence
+        shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+        # Gather values from seq using shifted indices
+        shifted_seq = seq.gather(1, shifted_idx)
+
+        return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+    """TimesFM model for quantile and mean prediction."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.context_len = config.context_length
+        self.horizon_len = config.horizon_length
+
+        self.decoder = TimesFmModel(config)
+
+        # quantile and mean output
+        self.horizon_ff_layer = TimesFmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * (1 + len(config.quantiles)),
+            hidden_dims=config.intermediate_size,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _preprocess(
+        self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Formats and pads raw inputs to feed into the model.
+
+        This function pads (or truncates) each time series to match the model
+        context length and stacks the results into batched tensors.
+
+        Args:
+            inputs: A list of 1d Tensors. Each Tensor is the context time series of
+                a single forecast task.
+            freq: list of frequencies
+
+        Returns:
+            A tuple of:
+            - the padded input time series to meet the model required context.
+            - the padding indicator.
+            - the frequency of each input time series, as an int32 tensor of shape `(batch_size, 1)`.
+        """
+        input_ts, input_padding, inp_freq = [], [], []
+
+        for i, ts in enumerate(inputs):
+            input_len = ts.shape[0]
+            padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
+            if input_len < self.context_len:
+                num_front_pad = self.context_len - input_len
+                ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
+                padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
+            elif input_len > self.context_len:
+                ts = ts[-self.context_len :]
+                padding = padding[-(self.context_len + self.horizon_len) :]
+
+            input_ts.append(ts)
+            input_padding.append(padding)
+            inp_freq.append(freq[i])
+
+        return (
+            torch.stack(input_ts, dim=0),
+            torch.stack(input_padding, dim=0),
+            torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
+        )
+
+    def _postprocess_output(
+        self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
+    ) -> torch.Tensor:
+        """Postprocess output of stacked transformer."""
+
+        # B x N x (H.Q)
+        output_ts = self.horizon_ff_layer(model_output)
+
+        # Reshape using view
+        b, n, _ = output_ts.shape
+        output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
+
+        mu, sigma = stats
+        return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
+
+    def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        losses = []
+        for i, q in enumerate(self.config.quantiles):
+            errors = targets - predictions[..., i]
+            loss = torch.max((q - 1) * errors, q * errors)
+            losses.append(loss.mean())
+        return torch.stack(losses).mean()
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        past_values: Sequence[torch.Tensor],
+        freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
+        window_size: Optional[int] = None,
+        future_values: Optional[torch.Tensor] = None,
+        forecast_context_len: Optional[int] = None,
+        return_forecast_on_context: bool = False,
+        truncate_negative: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TimesFmOutputForPrediction:
+        r"""
+        window_size (`int`, *optional*):
+            Window size of trend + residual decomposition. If None then we do not do decomposition.
+        future_values (`torch.Tensor`, *optional*):
+            Optional future time series values to be used for loss computation.
+        forecast_context_len (`int`, *optional*):
+            Optional max context length.
+        return_forecast_on_context (`bool`, *optional*):
+            True to return the forecast on the context when available, i.e. after the first input patch.
+        truncate_negative (`bool`, *optional*):
+            Truncate to only non-negative values if all of the contexts are non-negative,
+            otherwise do nothing.
+        output_attentions (`bool`, *optional*):
+            Whether to output the attentions.
+        output_hidden_states (`bool`, *optional*):
+            Whether to output the hidden states.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import TimesFmModelForPrediction
+
+        >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
+
+        >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
+        >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
+
+        >>> # Generate
+        >>> with torch.no_grad():
+        ...     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
+        ...     point_forecast_conv = outputs.mean_predictions
+        ...     quantile_forecast_conv = outputs.full_predictions
+        ```
+        """
+        if forecast_context_len is None:
+            fcontext_len = self.context_len
+        else:
+            fcontext_len = forecast_context_len
+
+        # Get device from first input tensor
+        device = past_values[0].device
+
+        # Truncate inputs to forecast_context_len
+        inputs = [ts[-fcontext_len:] for ts in past_values]
+        inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
+
+        if window_size is not None:
+            new_inputs = []
+            new_freqs = []
+            for i, ts in enumerate(inputs):
+                new_inputs.extend(self._timesfm_moving_average(ts, window_size))
+                if freq is not None:
+                    new_freqs.extend([freq[i]] * 2)
+            inputs = new_inputs
+            if freq is not None:
+                freq = new_freqs
+
+        if freq is None:
+            logger.info("No frequency provided via `freq`. Defaulting to high (0).")
+            freq = [0] * len(inputs)
+
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
+        # Move tensors to the same device as input
+        input_ts = input_ts.to(device)
+        input_padding = input_padding.to(device)
+        inp_freq = inp_freq.to(device)
+
+        final_out = input_ts
+        context_len = final_out.shape[1]
+        full_outputs = []
+
+        if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
+            raise ValueError(
+                "Length of paddings must match length of input + horizon_len:"
+                f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
+            )
+        output_patch_len = self.config.horizon_length
+
+        num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
+        for step_index in range(num_decode_patches):
+            current_padding = input_padding[:, 0 : final_out.shape[1]]
+            input_ts = final_out[:, -fcontext_len:]
+            input_padding = current_padding[:, -fcontext_len:]
+            decoder_output = self.decoder(
+                past_values=input_ts,
+                past_values_padding=input_padding,
+                freq=inp_freq,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            fprop_outputs = self._postprocess_output(
+                decoder_output.last_hidden_state,
+                (decoder_output.loc, decoder_output.scale),
+            )
+
+            if return_forecast_on_context and step_index == 0:
+                # For the first decoding step, collect the model forecast on the
+                # context, except for the first input patch whose forecast is unavailable.
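+                # Along the last axis of `fprop_outputs`, index 0 is the mean forecast and
+                # indices 1..len(quantiles) hold the quantile forecasts (see `_postprocess_output`).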
+                new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
+                # We have to use reshape and not view for non-contiguous memory
+                new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
+
+                full_outputs.append(new_full_ts)
+
+            # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+            # (full batch, last patch, output_patch_len, all output indices)
+            full_outputs.append(new_full_ts)
+            final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+        if return_forecast_on_context:
+            # `full_outputs` indexing starts right after the first input patch.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[
+                :, : (context_len - self.config.patch_length + self.horizon_len), :
+            ]
+        else:
+            # `full_outputs` indexing starts at the forecast horizon.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
+
+        mean_outputs = full_outputs[:, :, 0]
+        if window_size is not None:
+            mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+            full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+        if inp_min >= 0 and truncate_negative:
+            # `torch.maximum` requires a tensor operand, so clamp to zero instead.
+            mean_outputs = torch.clamp(mean_outputs, min=0.0)
+            full_outputs = torch.clamp(full_outputs, min=0.0)
+
+        loss = None
+        if future_values is not None:
+            mse_loss = F.mse_loss(mean_outputs, future_values)
+            quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
+            loss = mse_loss + quantile_loss
+
+        return TimesFmOutputForPrediction(
+            last_hidden_state=decoder_output.last_hidden_state,
+            attentions=decoder_output.attentions if output_attentions else None,
+            hidden_states=decoder_output.hidden_states if output_hidden_states else None,
+            mean_predictions=mean_outputs,
+            full_predictions=full_outputs,
+            loss=loss,
+        )
+
+    @staticmethod
+    def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
+        """Calculates the moving average using PyTorch's convolution function."""
+        # Pad with zeros to handle initial window positions
+        arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
+        # Create a convolution kernel
+        kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
+        # Apply convolution to calculate the moving average
+        smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
+        return [smoothed_arr, arr - smoothed_arr]
+
+
+__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
diff --git a/docs/transformers/src/transformers/models/timesformer/__init__.py b/docs/transformers/src/transformers/models/timesformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0cc9d3d6a68569b0e0e94c94573d7191b153a68
--- /dev/null
+++ b/docs/transformers/src/transformers/models/timesformer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
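+# `_LazyModule` (used below) defers the actual submodule imports until an
+# attribute is first accessed, so `import transformers` stays fast.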
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_timesformer import * + from .modeling_timesformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/timesformer/configuration_timesformer.py b/docs/transformers/src/transformers/models/timesformer/configuration_timesformer.py new file mode 100644 index 0000000000000000000000000000000000000000..edb69af230f901c1d8c237c24354a8fde9c5c909 --- /dev/null +++ b/docs/transformers/src/transformers/models/timesformer/configuration_timesformer.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TimeSformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimesformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TimesformerModel`]. It is used to instantiate a + TimeSformer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the TimeSformer + [facebook/timesformer-base-finetuned-k600](https://huggingface.co/facebook/timesformer-base-finetuned-k600) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + num_frames (`int`, *optional*, defaults to 8): + The number of frames in each video. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + attention_type (`str`, *optional*, defaults to `"divided_space_time"`): + The attention type to use. Must be one of `"divided_space_time"`, `"space_only"`, `"joint_space_time"`. + drop_path_rate (`float`, *optional*, defaults to 0): + The dropout ratio for stochastic depth. + + Example: + + ```python + >>> from transformers import TimesformerConfig, TimesformerModel + + >>> # Initializing a TimeSformer timesformer-base style configuration + >>> configuration = TimesformerConfig() + + >>> # Initializing a model from the configuration + >>> model = TimesformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "timesformer" + + def __init__( + self, + image_size=224, + patch_size=16, + num_channels=3, + num_frames=8, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + qkv_bias=True, + attention_type="divided_space_time", + drop_path_rate=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_frames = num_frames + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + + self.attention_type = attention_type + self.drop_path_rate = drop_path_rate + + +__all__ = ["TimesformerConfig"] diff --git a/docs/transformers/src/transformers/models/timesformer/convert_timesformer_to_pytorch.py b/docs/transformers/src/transformers/models/timesformer/convert_timesformer_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..cda9b0c182726781f25eb7d983dd36c93517e98a --- /dev/null +++ b/docs/transformers/src/transformers/models/timesformer/convert_timesformer_to_pytorch.py @@ -0,0 +1,253 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
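+# Example invocation (paths and URL are illustrative; the flags are defined in the
+# argparse section at the bottom of this script):
+#   python convert_timesformer_to_pytorch.py \
+#       --checkpoint_url "<google-drive-direct-download-url>" \
+#       --model_name timesformer-base-finetuned-k400 \
+#       --pytorch_dump_folder_path ./timesformer-base-finetuned-k400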
+"""Convert TimeSformer checkpoints from the original repository: https://github.com/MCG-NJU/TimeSformer""" + +import argparse +import json + +import gdown +import numpy as np +import torch +from huggingface_hub import hf_hub_download + +from transformers import TimesformerConfig, TimesformerForVideoClassification, VideoMAEImageProcessor + + +def get_timesformer_config(model_name): + config = TimesformerConfig() + + if "large" in model_name: + config.num_frames = 96 + + if "hr" in model_name: + config.num_frames = 16 + config.image_size = 448 + + repo_id = "huggingface/label-files" + if "k400" in model_name: + config.num_labels = 400 + filename = "kinetics400-id2label.json" + elif "k600" in model_name: + config.num_labels = 600 + filename = "kinetics600-id2label.json" + elif "ssv2" in model_name: + config.num_labels = 174 + filename = "something-something-v2-id2label.json" + else: + raise ValueError("Model name should either contain 'k400', 'k600' or 'ssv2'.") + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + return config + + +def rename_key(name): + if "encoder." in name: + name = name.replace("encoder.", "") + if "cls_token" in name: + name = name.replace("cls_token", "timesformer.embeddings.cls_token") + if "pos_embed" in name: + name = name.replace("pos_embed", "timesformer.embeddings.position_embeddings") + if "time_embed" in name: + name = name.replace("time_embed", "timesformer.embeddings.time_embeddings") + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "timesformer.embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "timesformer.embeddings.norm") + if "blocks" in name: + name = name.replace("blocks", "timesformer.encoder.layer") + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name and "bias" not in name and "temporal" not in name: + name = name.replace("attn", "attention.self") + if "attn" in name and "temporal" not in name: + name = name.replace("attn", "attention.attention") + if "temporal_norm1" in name: + name = name.replace("temporal_norm1", "temporal_layernorm") + if "temporal_attn.proj" in name: + name = name.replace("temporal_attn", "temporal_attention.output.dense") + if "temporal_fc" in name: + name = name.replace("temporal_fc", "temporal_dense") + if "norm1" in name and "temporal" not in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "norm.weight" in name and "fc" not in name and "temporal" not in name: + name = name.replace("norm.weight", "timesformer.layernorm.weight") + if "norm.bias" in name and "fc" not in name and "temporal" not in name: + name = name.replace("norm.bias", "timesformer.layernorm.bias") + if "head" in name: + name = name.replace("head", "classifier") + + return name + + +def convert_state_dict(orig_state_dict, config): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if key.startswith("model."): + key = key.replace("model.", "") + + if "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[1]) + prefix = "timesformer.encoder.layer." 
+ if "temporal" in key: + postfix = ".temporal_attention.attention.qkv." + else: + postfix = ".attention.attention.qkv." + if "weight" in key: + orig_state_dict[f"{prefix}{layer_num}{postfix}weight"] = val + else: + orig_state_dict[f"{prefix}{layer_num}{postfix}bias"] = val + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +# We will verify our results on a video of eating spaghetti +# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227] +def prepare_video(): + file = hf_hub_download( + repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset" + ) + video = np.load(file) + return list(video) + + +def convert_timesformer_checkpoint(checkpoint_url, pytorch_dump_folder_path, model_name, push_to_hub): + config = get_timesformer_config(model_name) + + model = TimesformerForVideoClassification(config) + + # download original checkpoint, hosted on Google Drive + output = "pytorch_model.bin" + gdown.cached_download(checkpoint_url, output, quiet=False) + files = torch.load(output, map_location="cpu", weights_only=True) + if "model" in files: + state_dict = files["model"] + elif "module" in files: + state_dict = files["module"] + else: + state_dict = files["model_state"] + new_state_dict = convert_state_dict(state_dict, config) + + model.load_state_dict(new_state_dict) + model.eval() + + # verify model on basic input + image_processor = VideoMAEImageProcessor(image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5]) + video = prepare_video() + inputs = image_processor(video[:8], return_tensors="pt") + + outputs = model(**inputs) + logits = outputs.logits + + model_names = [ + # Kinetics-400 checkpoints (hr = high resolution input of 448px instead of 224px) + "timesformer-base-finetuned-k400", + "timesformer-large-finetuned-k400", + "timesformer-hr-finetuned-k400", + # Kinetics-600 checkpoints (hr = high resolution input of 448px instead of 224px) + "timesformer-base-finetuned-k600", + "timesformer-large-finetuned-k600", + "timesformer-hr-finetuned-k600", + # Something-Something-v2 checkpoints (hr = high resolution input of 448px instead of 224px) + "timesformer-base-finetuned-ssv2", + "timesformer-large-finetuned-ssv2", + "timesformer-hr-finetuned-ssv2", + ] + + # NOTE: logits were tested with image_mean and image_std equal to [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5] + if model_name == "timesformer-base-finetuned-k400": + expected_shape = torch.Size([1, 400]) + expected_slice = torch.tensor([-0.3016, -0.7713, -0.4205]) + elif model_name == "timesformer-base-finetuned-k600": + expected_shape = torch.Size([1, 600]) + expected_slice = torch.tensor([-0.7267, -0.7466, 3.2404]) + elif model_name == "timesformer-base-finetuned-ssv2": + expected_shape = torch.Size([1, 174]) + expected_slice = torch.tensor([-0.9059, 0.6433, -3.1457]) + elif model_name == "timesformer-large-finetuned-k400": + expected_shape = torch.Size([1, 400]) + expected_slice = torch.tensor([0, 0, 0]) + elif model_name == "timesformer-large-finetuned-k600": + expected_shape = torch.Size([1, 600]) + expected_slice = torch.tensor([0, 0, 0]) + elif model_name == "timesformer-large-finetuned-ssv2": + expected_shape = torch.Size([1, 174]) + expected_slice = torch.tensor([0, 0, 0]) + elif model_name == "timesformer-hr-finetuned-k400": + expected_shape = torch.Size([1, 400]) + expected_slice = torch.tensor([-0.9617, -3.7311, -3.7708]) + elif model_name == "timesformer-hr-finetuned-k600": + expected_shape = torch.Size([1, 600]) + expected_slice = 
torch.tensor([2.5273, 0.7127, 1.8848]) + elif model_name == "timesformer-hr-finetuned-ssv2": + expected_shape = torch.Size([1, 174]) + expected_slice = torch.tensor([-3.6756, -0.7513, 0.7180]) + else: + raise ValueError(f"Model name not supported. Should be one of {model_names}") + + # verify logits + assert logits.shape == expected_shape + assert torch.allclose(logits[0, :3], expected_slice, atol=1e-4) + print("Logits ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + model.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing to the hub...") + model.push_to_hub(f"fcakyon/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://drive.google.com/u/1/uc?id=17yvuYp9L4mn-HpIcK5Zo6K3UoOy1kA5l&export=download", + type=str, + help=( + "URL of the original PyTorch checkpoint (on Google Drive) you'd like to convert. Should be a direct" + " download link." + ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--model_name", default="timesformer-base-finetuned-k400", type=str, help="Name of the model.") + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_timesformer_checkpoint( + args.checkpoint_url, args.pytorch_dump_folder_path, args.model_name, args.push_to_hub + ) diff --git a/docs/transformers/src/transformers/models/timesformer/modeling_timesformer.py b/docs/transformers/src/transformers/models/timesformer/modeling_timesformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd79a6cec0425991dba299e1886dd4120a1dd16 --- /dev/null +++ b/docs/transformers/src/transformers/models/timesformer/modeling_timesformer.py @@ -0,0 +1,816 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
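As a usage note for the conversion script above: besides the CLI, it can be driven from Python. This is a hypothetical invocation that assumes the script's directory is on `sys.path`, `gdown` is installed, and the Google Drive URL (the script's default) is reachable:

```python
from convert_timesformer_to_pytorch import convert_timesformer_checkpoint

convert_timesformer_checkpoint(
    checkpoint_url="https://drive.google.com/u/1/uc?id=17yvuYp9L4mn-HpIcK5Zo6K3UoOy1kA5l&export=download",
    pytorch_dump_folder_path="./timesformer-base-finetuned-k400",
    model_name="timesformer-base-finetuned-k400",
    push_to_hub=False,
)
```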
+"""PyTorch TimeSformer model.""" + +import collections +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_timesformer import TimesformerConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TimesformerConfig" +_CHECKPOINT_FOR_DOC = "facebook/timesformer" + + +# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L155 +class TimesformerPatchEmbeddings(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, config): + super().__init__() + + image_size = config.image_size + patch_size = config.patch_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.projection = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_frames, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width) + + embeddings = self.projection(pixel_values) + patch_width = embeddings.size(-1) + embeddings = embeddings.flatten(2).transpose(1, 2) + return embeddings, num_frames, patch_width + + +class TimesformerEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. 
+ """ + + def __init__(self, config): + super().__init__() + + embed_dim = config.hidden_size + num_frames = config.num_frames + drop_rate = config.hidden_dropout_prob + attention_type = config.attention_type + + self.attention_type = attention_type + self.patch_embeddings = TimesformerPatchEmbeddings(config) + self.num_patches = self.patch_embeddings.num_patches + + # Positional Embeddings + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + if attention_type != "space_only": + self.time_embeddings = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) + self.time_drop = nn.Dropout(p=drop_rate) + + def forward(self, pixel_values): + batch_size = pixel_values.shape[0] + + # create patch embeddings + embeddings, num_frames, patch_width = self.patch_embeddings(pixel_values) + + cls_tokens = self.cls_token.expand(embeddings.size(0), -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # resizing the positional embeddings in case they don't match the input at inference + if embeddings.size(1) != self.position_embeddings.size(1): + position_embeddings = self.position_embeddings + cls_pos_embed = position_embeddings[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = position_embeddings[0, 1:, :].unsqueeze(0).transpose(1, 2) + patch_num = int(other_pos_embed.size(2) ** 0.5) + patch_height = embeddings.size(1) // patch_width + other_pos_embed = other_pos_embed.reshape(1, embeddings.size(2), patch_num, patch_num) + new_pos_embed = nn.functional.interpolate( + other_pos_embed, size=(patch_height, patch_width), mode="nearest" + ) + new_pos_embed = new_pos_embed.flatten(2) + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + embeddings = embeddings + new_pos_embed + else: + embeddings = embeddings + self.position_embeddings + embeddings = self.pos_drop(embeddings) + + # Time Embeddings + if self.attention_type != "space_only": + cls_tokens = embeddings[:batch_size, 0, :].unsqueeze(1) + embeddings = embeddings[:, 1:] + _, patch_height, patch_width = embeddings.shape + embeddings = ( + embeddings.reshape(batch_size, num_frames, patch_height, patch_width) + .permute(0, 2, 1, 3) + .reshape(batch_size * patch_height, num_frames, patch_width) + ) + # Resizing time embeddings in case they don't match + if num_frames != self.time_embeddings.size(1): + time_embeddings = self.time_embeddings.transpose(1, 2) + new_time_embeddings = nn.functional.interpolate(time_embeddings, size=(num_frames), mode="nearest") + new_time_embeddings = new_time_embeddings.transpose(1, 2) + embeddings = embeddings + new_time_embeddings + else: + embeddings = embeddings + self.time_embeddings + embeddings = self.time_drop(embeddings) + embeddings = embeddings.view(batch_size, patch_height, num_frames, patch_width).reshape( + batch_size, patch_height * num_frames, patch_width + ) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + return embeddings + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->TimeSformer +class TimeSformerDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L57 +class TimesformerSelfAttention(nn.Module): + def __init__(self, config: TimesformerConfig): + super().__init__() + + num_heads = config.num_attention_heads + qkv_bias = config.qkv_bias + attention_dropout_prob = config.attention_probs_dropout_prob + + self.num_heads = num_heads + head_dim = config.hidden_size // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attention_dropout_prob) + + def forward(self, hidden_states, output_attentions: bool = False): + batch_size, hidden_size, num_channels = hidden_states.shape + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, hidden_size, 3, self.num_heads, num_channels // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + query, key, value = qkv[0], qkv[1], qkv[2] + + attention_probs = (query @ key.transpose(-2, -1)) * self.scale + attention_probs = attention_probs.softmax(dim=-1) + attention_probs = self.attn_drop(attention_probs) + + context_layer = (attention_probs @ value).transpose(1, 2).reshape(batch_size, hidden_size, num_channels) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class TimesformerSelfOutput(nn.Module): + """ + The residual connection is defined in TimesformerLayer instead of here (as is the case with other models), due to + the layernorm applied before each block. 
+ """ + + def __init__(self, config: TimesformerConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class TimeSformerAttention(nn.Module): + def __init__(self, config: TimesformerConfig) -> None: + super().__init__() + self.attention = TimesformerSelfAttention(config) + self.output = TimesformerSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, output_attentions) + + attention_output = self.output(self_outputs[0]) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L39 +class TimesformerIntermediate(nn.Module): + def __init__(self, config: TimesformerConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class TimesformerOutput(nn.Module): + def __init__(self, config: TimesformerConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L89 +class TimesformerLayer(nn.Module): + def __init__(self, config: TimesformerConfig, layer_index: int) -> None: + super().__init__() + + attention_type = config.attention_type + + drop_path_rates = [ + x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu") + ] # stochastic depth decay rule + drop_path_rate = drop_path_rates[layer_index] + + self.drop_path = TimeSformerDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.attention = TimeSformerAttention(config) + self.intermediate = TimesformerIntermediate(config) + self.output = TimesformerOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.config = config + self.attention_type = attention_type + if attention_type not in ["divided_space_time", "space_only", "joint_space_time"]: + raise ValueError("Unknown attention type: {}".format(attention_type)) + + # Temporal Attention Parameters + if self.attention_type == "divided_space_time": + self.temporal_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + 
self.temporal_attention = TimeSformerAttention(config) + self.temporal_dense = nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False): + num_frames = self.config.num_frames + num_patch_width = self.config.image_size // self.config.patch_size + batch_size = hidden_states.shape[0] + num_spatial_tokens = (hidden_states.size(1) - 1) // num_frames + num_patch_height = num_spatial_tokens // num_patch_width + + if self.attention_type in ["space_only", "joint_space_time"]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), output_attentions=output_attentions + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + hidden_states = hidden_states + self.drop_path(attention_output) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + layer_output = hidden_states + self.drop_path(layer_output) + + outputs = (layer_output,) + outputs + + return outputs + + elif self.attention_type == "divided_space_time": + # Temporal + temporal_embedding = hidden_states[:, 1:, :] + temporal_embedding = temporal_embedding.reshape( + batch_size, num_patch_height, num_patch_width, num_frames, temporal_embedding.shape[2] + ).reshape(batch_size * num_patch_height * num_patch_width, num_frames, temporal_embedding.shape[2]) + + temporal_attention_outputs = self.temporal_attention( + self.temporal_layernorm(temporal_embedding), + ) + attention_output = temporal_attention_outputs[0] + + residual_temporal = self.drop_path(attention_output) + + residual_temporal = residual_temporal.reshape( + batch_size, num_patch_height, num_patch_width, num_frames, residual_temporal.shape[2] + ).reshape(batch_size, num_patch_height * num_patch_width * num_frames, residual_temporal.shape[2]) + residual_temporal = self.temporal_dense(residual_temporal) + temporal_embedding = hidden_states[:, 1:, :] + residual_temporal + + # Spatial + init_cls_token = hidden_states[:, 0, :].unsqueeze(1) + cls_token = init_cls_token.repeat(1, num_frames, 1) + cls_token = cls_token.reshape(batch_size * num_frames, 1, cls_token.shape[2]) + spatial_embedding = temporal_embedding + spatial_embedding = ( + spatial_embedding.reshape( + batch_size, num_patch_height, num_patch_width, num_frames, spatial_embedding.shape[2] + ) + .permute(0, 3, 1, 2, 4) + .reshape(batch_size * num_frames, num_patch_height * num_patch_width, spatial_embedding.shape[2]) + ) + spatial_embedding = torch.cat((cls_token, spatial_embedding), 1) + + spatial_attention_outputs = self.attention( + self.layernorm_before(spatial_embedding), output_attentions=output_attentions + ) + attention_output = spatial_attention_outputs[0] + outputs = spatial_attention_outputs[1:] # add self attentions if we output attention weights + + residual_spatial = self.drop_path(attention_output) + + # Taking care of CLS token + cls_token = residual_spatial[:, 0, :] + cls_token = cls_token.reshape(batch_size, num_frames, cls_token.shape[1]) + cls_token = torch.mean(cls_token, 1, True) # averaging for every frame + residual_spatial = residual_spatial[:, 1:, :] + residual_spatial = ( + residual_spatial.reshape( + batch_size, num_frames, num_patch_height, num_patch_width, residual_spatial.shape[2] + ) + .permute(0, 2, 3, 1, 4) + .reshape(batch_size, num_patch_height * num_patch_width * num_frames, residual_spatial.shape[2]) + ) + 
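+ # `residual_spatial` is back in (batch_size, height * width * frames,
+ # hidden_size) layout, matching `temporal_embedding`, so the spatial
+ # residual and the frame-averaged CLS token can be added onto the
+ # hidden states below.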
residual = residual_spatial + hidden_states = temporal_embedding + + # Mlp + hidden_states = torch.cat((init_cls_token, hidden_states), 1) + torch.cat((cls_token, residual), 1) + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + layer_output = hidden_states + self.drop_path(layer_output) + + outputs = (layer_output,) + outputs + + return outputs + + +class TimesformerEncoder(nn.Module): + def __init__(self, config: TimesformerConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([TimesformerLayer(config, ind) for ind in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class TimesformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TimesformerConfig + base_model_prefix = "timesformer" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["TimesformerLayer"] + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Conv2d)): + nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, TimesformerEmbeddings): + nn.init.trunc_normal_(module.cls_token, std=self.config.initializer_range) + nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) + module.patch_embeddings.apply(self._init_weights) + + +TIMESFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`TimesformerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +TIMESFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`VideoMAEImageProcessor.preprocess`] for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare TimeSformer Model transformer outputting raw hidden-states without any specific head on top.", + TIMESFORMER_START_DOCSTRING, +) +class TimesformerModel(TimesformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = TimesformerEmbeddings(config) + self.encoder = TimesformerEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TIMESFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import av + >>> import numpy as np + + >>> from transformers import AutoImageProcessor, TimesformerModel + >>> from huggingface_hub import hf_hub_download + + >>> np.random.seed(0) + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`List[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + + >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' + ... converted_len = int(clip_len * frame_sample_rate) + ... 
end_idx = np.random.randint(converted_len, seg_len) + ... start_idx = end_idx - converted_len + ... indices = np.linspace(start_idx, end_idx, num=clip_len) + ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64) + ... return indices + + + >>> # video clip consists of 300 frames (10 seconds at 30 FPS) + >>> file_path = hf_hub_download( + ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset" + ... ) + >>> container = av.open(file_path) + + >>> # sample 8 frames + >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=container.streams.video[0].frames) + >>> video = read_video_pyav(container, indices) + + >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base") + >>> model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400") + + >>> # prepare video for the model + >>> inputs = image_processor(list(video), return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 1569, 768] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + if self.layernorm is not None: + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """TimeSformer Model transformer with a video classification head on top (a linear layer on top of the final hidden state +of the [CLS] token) e.g. for ImageNet.""", + TIMESFORMER_START_DOCSTRING, +) +class TimesformerForVideoClassification(TimesformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.timesformer = TimesformerModel(config) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TIMESFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> import av + >>> import torch + >>> import numpy as np + + >>> from transformers import AutoImageProcessor, TimesformerForVideoClassification + >>> from huggingface_hub import hf_hub_download + + >>> np.random.seed(0) + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`List[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + + >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' + ... converted_len = int(clip_len * frame_sample_rate) + ... end_idx = np.random.randint(converted_len, seg_len) + ... start_idx = end_idx - converted_len + ... indices = np.linspace(start_idx, end_idx, num=clip_len) + ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64) + ... return indices + + + >>> # video clip consists of 300 frames (10 seconds at 30 FPS) + >>> file_path = hf_hub_download( + ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset" + ... ) + >>> container = av.open(file_path) + + >>> # sample 8 frames + >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames) + >>> video = read_video_pyav(container, indices) + + >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics") + >>> model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400") + + >>> inputs = image_processor(list(video), return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + ... 
logits = outputs.logits + + >>> # model predicts one of the 400 Kinetics-400 classes + >>> predicted_label = logits.argmax(-1).item() + >>> print(model.config.id2label[predicted_label]) + eating spaghetti + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.timesformer( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0][:, 0] + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["TimesformerModel", "TimesformerForVideoClassification", "TimesformerPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/timm_backbone/__init__.py b/docs/transformers/src/transformers/models/timm_backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a58f096b1e48227518b04a11250d66635c65701 --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_backbone/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
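For reference, the `[1, 1569, 768]` hidden-state shape in the `TimesformerModel` docstring example follows directly from the default configuration; a small sanity check:

```python
# 1 CLS token + (224 / 16)^2 = 196 patches per frame, over 8 frames.
image_size, patch_size, num_frames = 224, 16, 8
seq_len = 1 + (image_size // patch_size) ** 2 * num_frames
assert seq_len == 1569
```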
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_timm_backbone import * + from .modeling_timm_backbone import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/docs/transformers/src/transformers/models/timm_backbone/configuration_timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6000698c92488d046ef482eb664920e5bfc58ec5 --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for Backbone models""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimmBackboneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`]. + + It is used to instantiate a timm backbone model according to the specified arguments, defining the model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone (`str`, *optional*): + The timm checkpoint to load. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + features_only (`bool`, *optional*, defaults to `True`): + Whether to output only the features or also the logits. + use_pretrained_backbone (`bool`, *optional*, defaults to `True`): + Whether to use a pretrained backbone. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). Will default to the last stage if unset. + freeze_batch_norm_2d (`bool`, *optional*, defaults to `False`): + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. 
+ + Example: + ```python + >>> from transformers import TimmBackboneConfig, TimmBackbone + + >>> # Initializing a timm backbone + >>> configuration = TimmBackboneConfig("resnet50") + + >>> # Initializing a model from the configuration + >>> model = TimmBackbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "timm_backbone" + + def __init__( + self, + backbone=None, + num_channels=3, + features_only=True, + use_pretrained_backbone=True, + out_indices=None, + freeze_batch_norm_2d=False, + **kwargs, + ): + super().__init__(**kwargs) + self.backbone = backbone + self.num_channels = num_channels + self.features_only = features_only + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = True + self.out_indices = out_indices if out_indices is not None else [-1] + self.freeze_batch_norm_2d = freeze_batch_norm_2d + + +__all__ = ["TimmBackboneConfig"] diff --git a/docs/transformers/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/docs/transformers/src/transformers/models/timm_backbone/modeling_timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..60a8b40463569dc909275d49f7b5a386ecf31c73 --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch + +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...utils import is_timm_available, is_torch_available, requires_backends +from ...utils.backbone_utils import BackboneMixin +from .configuration_timm_backbone import TimmBackboneConfig + + +if is_timm_available(): + import timm + + +if is_torch_available(): + from torch import Tensor + + +class TimmBackbone(PreTrainedModel, BackboneMixin): + """ + Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the + other models in the library keeping the same API. + """ + + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + config_class = TimmBackboneConfig + + def __init__(self, config, **kwargs): + requires_backends(self, "timm") + super().__init__(config) + self.config = config + + if config.backbone is None: + raise ValueError("backbone is not set in the config. Please set it to a timm model name.") + + if hasattr(config, "out_features") and config.out_features is not None: + raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.") + + pretrained = getattr(config, "use_pretrained_backbone", None) + if pretrained is None: + raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.") + + # We just take the final layer by default. This matches the default for the transformers models. 
+ out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,) + + in_chans = kwargs.pop("in_chans", config.num_channels) + self._backbone = timm.create_model( + config.backbone, + pretrained=pretrained, + # This is currently not possible for transformer architectures. + features_only=config.features_only, + in_chans=in_chans, + out_indices=out_indices, + **kwargs, + ) + + # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively + if getattr(config, "freeze_batch_norm_2d", False): + self.freeze_batch_norm_2d() + + # These are used to control the output of the model when called. If output_hidden_states is True, then + # return_layers is modified to include all layers. + self._return_layers = { + layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts() + } + self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)} + super()._init_backbone(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_backbone import TimmBackboneConfig + + config = kwargs.pop("config", TimmBackboneConfig()) + + use_timm = kwargs.pop("use_timm_backbone", True) + if not use_timm: + raise ValueError("use_timm_backbone must be True for timm backbones") + + num_channels = kwargs.pop("num_channels", config.num_channels) + features_only = kwargs.pop("features_only", config.features_only) + use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone) + out_indices = kwargs.pop("out_indices", config.out_indices) + config = TimmBackboneConfig( + backbone=pretrained_model_name_or_path, + num_channels=num_channels, + features_only=features_only, + use_pretrained_backbone=use_pretrained_backbone, + out_indices=out_indices, + ) + return super()._from_config(config, **kwargs) + + def freeze_batch_norm_2d(self): + timm.utils.model.freeze_batch_norm_2d(self._backbone) + + def unfreeze_batch_norm_2d(self): + timm.utils.model.unfreeze_batch_norm_2d(self._backbone) + + def _init_weights(self, module): + """ + Empty init weights function to ensure compatibility of the class in the library. 
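+ Weights are created and initialized by `timm.create_model` in `__init__`,
+ so nothing needs to be (re-)initialized here.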
+ """ + pass + + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[BackboneOutput, Tuple[Tensor, ...]]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if output_attentions: + raise ValueError("Cannot output attentions for timm backbones at the moment") + + if output_hidden_states: + # We modify the return layers to include all the stages of the backbone + self._backbone.return_layers = self._all_layers + hidden_states = self._backbone(pixel_values, **kwargs) + self._backbone.return_layers = self._return_layers + feature_maps = tuple(hidden_states[i] for i in self.out_indices) + else: + feature_maps = self._backbone(pixel_values, **kwargs) + hidden_states = None + + feature_maps = tuple(feature_maps) + hidden_states = tuple(hidden_states) if hidden_states is not None else None + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output = output + (hidden_states,) + return output + + return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None) + + +__all__ = ["TimmBackbone"] diff --git a/docs/transformers/src/transformers/models/timm_wrapper/__init__.py b/docs/transformers/src/transformers/models/timm_wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9fbc4150412a73ec6d77a09db13c105e95e50c06 --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_wrapper/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_timm_wrapper import * + from .modeling_timm_wrapper import * + from .processing_timm_wrapper import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/docs/transformers/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4313ad854add5955150268d43d4f3a6b97e73ed1 --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for TimmWrapper models""" + +from typing import Any, Dict + +from ...configuration_utils import PretrainedConfig +from ...utils import is_timm_available, logging, requires_backends + + +if is_timm_available(): + from timm.data import ImageNetInfo, infer_imagenet_subset + + +logger = logging.get_logger(__name__) + + +class TimmWrapperConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. + + It is used to instantiate a timm model according to the specified arguments, defining the model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default + imagenet models is set to `None` due to occlusions in the label descriptions. + + Args: + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + do_pooling (`bool`, *optional*, defaults to `True`): + Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. + + Example: + ```python + >>> from transformers import TimmWrapperModel + + >>> # Initializing a timm model + >>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k") + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "timm_wrapper" + + def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs): + self.initializer_range = initializer_range + self.do_pooling = do_pooling + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs): + label_names = config_dict.get("label_names", None) + is_custom_model = "num_labels" in kwargs or "id2label" in kwargs + + # if no labels added to config, use imagenet labeller in timm + if label_names is None and not is_custom_model: + requires_backends(cls, ["timm"]) + imagenet_subset = infer_imagenet_subset(config_dict) + if imagenet_subset: + dataset_info = ImageNetInfo(imagenet_subset) + synsets = dataset_info.label_names() + label_descriptions = dataset_info.label_descriptions(as_dict=True) + label_names = [label_descriptions[synset] for synset in synsets] + + if label_names is not None and not is_custom_model: + kwargs["id2label"] = dict(enumerate(label_names)) + + # if all label names are unique, create label2id mapping as well + if len(set(label_names)) == len(label_names): + kwargs["label2id"] = {name: i for i, name in enumerate(label_names)} + else: + kwargs["label2id"] = None + + # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. 
+ # We are removing these attributes in order to have the native `transformers` num_labels attribute in config + # and to avoid duplicate attributes + num_labels_in_kwargs = kwargs.pop("num_labels", None) + num_labels_in_dict = config_dict.pop("num_classes", None) + + # passed num_labels has priority over num_classes in config_dict + kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict + + # pop num_classes from "pretrained_cfg", + # it is not necessary to have it, only root one is used in timm + if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: + config_dict["pretrained_cfg"].pop("num_classes", None) + + return super().from_dict(config_dict, **kwargs) + + def to_dict(self) -> Dict[str, Any]: + output = super().to_dict() + output["num_classes"] = self.num_labels + output["label_names"] = list(self.id2label.values()) + output.pop("id2label", None) + output.pop("label2id", None) + return output + + +__all__ = ["TimmWrapperConfig"] diff --git a/docs/transformers/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py b/docs/transformers/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..de54f8a936cda645e622664864612fd56f26407b --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import to_pil_image +from ...image_utils import ImageInput, make_list_of_images +from ...utils import TensorType, logging, requires_backends +from ...utils.import_utils import is_timm_available, is_torch_available, requires + + +if is_timm_available(): + import timm + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +@requires(backends=("torch", "timm", "torchvision")) +class TimmWrapperImageProcessor(BaseImageProcessor): + """ + Wrapper class for timm models to be used within transformers. + + Args: + pretrained_cfg (`Dict[str, Any]`): + The configuration of the pretrained model used to resolve evaluation and + training transforms. + architecture (`Optional[str]`, *optional*): + Name of the architecture of the model. 
+ """ + + main_input_name = "pixel_values" + + def __init__( + self, + pretrained_cfg: Dict[str, Any], + architecture: Optional[str] = None, + **kwargs, + ): + requires_backends(self, "timm") + super().__init__(architecture=architecture) + + self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False) + self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False) + + # useful for training, see examples/pytorch/image-classification/run_image_classification.py + self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True) + + # If `ToTensor` is in the transforms, then the input should be numpy array or PIL image. + # Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used + # which can handle both numpy arrays / PIL images and tensors. + self._not_supports_tensor_input = any( + transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + """ + output = super().to_dict() + output.pop("train_transforms", None) + output.pop("val_transforms", None) + output.pop("_not_supports_tensor_input", None) + return output + + @classmethod + def get_image_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Get the image processor dict for the model. + """ + image_processor_filename = kwargs.pop("image_processor_filename", "config.json") + return super().get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs + ) + + def preprocess( + self, + images: ImageInput, + return_tensors: Optional[Union[str, TensorType]] = "pt", + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. + """ + if return_tensors != "pt": + raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}") + + if self._not_supports_tensor_input and isinstance(images, torch.Tensor): + images = images.cpu().numpy() + + # If the input is a torch tensor, then no conversion is needed + # Otherwise, we need to pass in a list of PIL images + if isinstance(images, torch.Tensor): + images = self.val_transforms(images) + # Add batch dimension if a single image + images = images.unsqueeze(0) if images.ndim == 3 else images + else: + images = make_list_of_images(images) + images = [to_pil_image(image) for image in images] + images = torch.stack([self.val_transforms(image) for image in images]) + + return BatchFeature({"pixel_values": images}, tensor_type=return_tensors) + + def save_pretrained(self, *args, **kwargs): + # disable it to make checkpoint the same as in `timm` library. + logger.warning_once( + "The `save_pretrained` method is disabled for TimmWrapperImageProcessor. " + "The image processor configuration is saved directly in `config.json` when " + "`save_pretrained` is called for saving the model." 
+ ) + + +__all__ = ["TimmWrapperImageProcessor"] diff --git a/docs/transformers/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/docs/transformers/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..48c1d85825335ba66f4328d1db9fa5e191f46ecc --- /dev/null +++ b/docs/transformers/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + is_timm_available, + replace_return_docstrings, + requires_backends, +) +from .configuration_timm_wrapper import TimmWrapperConfig + + +if is_timm_available(): + import timm + + +@dataclass +class TimmWrapperModelOutput(ModelOutput): + """ + Output class for models TimmWrapperModel, containing the last hidden states, an optional pooled output, + and optional hidden states. + + Args: + last_hidden_state (`torch.FloatTensor`): + The last hidden state of the model, output before applying the classification head. + pooler_output (`torch.FloatTensor`, *optional*): + The pooled output derived from the last hidden state, if applicable. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers. + Returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`. + attentions (`tuple(torch.FloatTensor)`, *optional*): + A tuple containing the intermediate attention weights of the model at the output of each layer. + Returned if `output_attentions=True` is set or if `config.output_attentions=True`. + Note: Currently, Timm models do not support attentions output. + """ + + last_hidden_state: torch.FloatTensor + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +TIMM_WRAPPER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`TimmWrapperImageProcessor.preprocess`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + **kwargs: + Additional keyword arguments passed along to the `timm` model forward. +""" + + +class TimmWrapperPreTrainedModel(PreTrainedModel): + main_input_name = "pixel_values" + config_class = TimmWrapperConfig + _no_split_modules = [] + model_tags = ["timm"] + + # used in Trainer to avoid passing `loss_kwargs` to model forward + accepts_loss_kwargs = False + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision", "timm"]) + super().__init__(*args, **kwargs) + + @staticmethod + def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: + """ + Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. + We don't want this behavior for timm wrapped models. Instead, this method adds a + "timm_model." prefix to enable loading official timm Hub checkpoints. + """ + if "timm_model." not in key: + return f"timm_model.{key}", True + return key, False + + def _fix_state_dict_key_on_save(self, key): + """ + Overrides original method to remove "timm_model." prefix from state_dict keys. + Makes the saved checkpoint compatible with the `timm` library. + """ + return key.replace("timm_model.", ""), True + + def load_state_dict(self, state_dict, *args, **kwargs): + """ + Override original method to fix state_dict keys on load for cases when weights are loaded + without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint). + """ + state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()} + return super().load_state_dict(state_dict, *args, **kwargs) + + def _init_weights(self, module): + """ + Initialize weights function to properly initialize Linear layer weights. + Since model architectures may vary, we assume only the classifier requires + initialization, while all other weights should be loaded from the checkpoint. + """ + if isinstance(module, (nn.Linear)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +class TimmWrapperModel(TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used in transformers. + """ + + def __init__(self, config: TimmWrapperConfig): + super().__init__(config) + # using num_classes=0 to avoid creating classification head + self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0) + self.post_init() + + @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TimmWrapperModelOutput, config_class=TimmWrapperConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[Union[bool, List[int]]] = None, + return_dict: Optional[bool] = None, + do_pooling: Optional[bool] = None, + **kwargs, + ) -> Union[TimmWrapperModelOutput, Tuple[Tensor, ...]]: + r""" + do_pooling (`bool`, *optional*): + Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the + `do_pooling` value from the config is used. + + Returns: + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from urllib.request import urlopen + >>> from transformers import AutoModel, AutoImageProcessor + + >>> # Load image + >>> image = Image.open(urlopen( + ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' + ... 
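The key-prefix round-trip implemented by `_fix_state_dict_key_on_load` and `_fix_state_dict_key_on_save` above, mirrored as a standalone sketch (helper names are illustrative):

```python
def on_load(key: str) -> str:
    # prefix keys from official timm Hub checkpoints with "timm_model."
    return key if "timm_model." in key else f"timm_model.{key}"


def on_save(key: str) -> str:
    # strip the prefix again so saved checkpoints stay timm-compatible
    return key.replace("timm_model.", "")


key = "layer1.0.conv1.weight"  # name as stored in a timm checkpoint
assert on_load(key) == "timm_model.layer1.0.conv1.weight"
assert on_save(on_load(key)) == key
```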
)) + + >>> # Load model and image processor + >>> checkpoint = "timm/resnet50.a1_in1k" + >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) + >>> model = AutoModel.from_pretrained(checkpoint).eval() + + >>> # Preprocess image + >>> inputs = image_processor(image) + + >>> # Forward pass + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # Get pooled output + >>> pooled_output = outputs.pooler_output + + >>> # Get last hidden state + >>> last_hidden_state = outputs.last_hidden_state + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling + + if output_attentions: + raise ValueError("Cannot set `output_attentions` for timm models.") + + if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"): + raise ValueError( + "The 'output_hidden_states' option cannot be set for this timm model. " + "To enable this feature, the 'forward_intermediates' method must be implemented " + "in the timm model (available in timm versions > 1.*). Please consider using a " + "different architecture or updating the timm package to a compatible version." + ) + + pixel_values = pixel_values.to(self.device, self.dtype) + + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + else: + last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs) + hidden_states = None + + if do_pooling: + # classification head is not created, applying pooling only + pooler_output = self.timm_model.forward_head(last_hidden_state) + else: + pooler_output = None + + if not return_dict: + outputs = (last_hidden_state, pooler_output, hidden_states) + outputs = tuple(output for output in outputs if output is not None) + return outputs + + return TimmWrapperModelOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=hidden_states, + ) + + +class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel): + """ + Wrapper class for timm models to be used in transformers for image classification. + """ + + def __init__(self, config: TimmWrapperConfig): + super().__init__(config) + + if config.num_labels == 0: + raise ValueError( + "You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. " + "Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, " + "or use `TimmWrapperModel` for feature extraction." 
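Continuing the `TimmWrapperModel` docstring example above: `output_hidden_states` may also be a list of stage indices, which `forward` hands to timm's `forward_intermediates` as `indices` (a hedged sketch; the exact index semantics depend on the installed timm version):

```python
>>> with torch.no_grad():
...     outputs = model(**inputs, output_hidden_states=[1, 2, 3], do_pooling=False)

>>> print(len(outputs.hidden_states))  # one tensor per requested stage index
>>> print(outputs.pooler_output)  # None, since do_pooling=False
```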
+ ) + + self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels) + self.num_labels = config.num_labels + self.post_init() + + @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=TimmWrapperConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[Union[bool, List[int]]] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[ImageClassifierOutput, Tuple[Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from urllib.request import urlopen + >>> from transformers import AutoModelForImageClassification, AutoImageProcessor + + >>> # Load image + >>> image = Image.open(urlopen( + ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png' + ... )) + + >>> # Load model and image processor + >>> checkpoint = "timm/resnet50.a1_in1k" + >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint) + >>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval() + + >>> # Preprocess image + >>> inputs = image_processor(image) + + >>> # Forward pass + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # Get top 5 predictions + >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if output_attentions: + raise ValueError("Cannot set `output_attentions` for timm models.") + + if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"): + raise ValueError( + "The 'output_hidden_states' option cannot be set for this timm model. " + "To enable this feature, the 'forward_intermediates' method must be implemented " + "in the timm model (available in timm versions > 1.*). Please consider using a " + "different architecture or updating the timm package to a compatible version." 
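A hedged sketch of the `num_labels` guard above: loading a head-less timm checkpoint into the classification wrapper requires passing `num_labels` explicitly (the checkpoint name is illustrative, and the call downloads weights):

```python
from transformers import TimmWrapperForImageClassification

# Without num_labels this raises the ValueError above for a checkpoint that was
# exported with num_classes=0; with it, a fresh 10-way head is created.
model = TimmWrapperForImageClassification.from_pretrained(
    "timm/vit_base_patch16_224.dino", num_labels=10
)
```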
+ ) + + pixel_values = pixel_values.to(self.device, self.dtype) + + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + logits = self.timm_model.forward_head(last_hidden_state) + else: + logits = self.timm_model(pixel_values, **kwargs) + hidden_states = None + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + outputs = (loss, logits, hidden_states) + outputs = tuple(output for output in outputs if output is not None) + return outputs + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + ) + + +__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"] diff --git a/docs/transformers/src/transformers/models/trocr/__init__.py b/docs/transformers/src/transformers/models/trocr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96ae1927c1b56002e6d3cf20f1fea145854f53f3 --- /dev/null +++ b/docs/transformers/src/transformers/models/trocr/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_trocr import * + from .modeling_trocr import * + from .processing_trocr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/trocr/configuration_trocr.py b/docs/transformers/src/transformers/models/trocr/configuration_trocr.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3aabbe1958700cf64b30154f9924a56a95831f --- /dev/null +++ b/docs/transformers/src/transformers/models/trocr/configuration_trocr.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. 
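The `problem_type` dispatch above, condensed into a standalone sketch:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    if num_labels == 1:
        return "regression"
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"


num_labels = 3
logits = torch.randn(4, num_labels)
labels = torch.tensor([0, 2, 1, 0])

problem_type = infer_problem_type(num_labels, labels)
if problem_type == "regression":
    loss = MSELoss()(logits.squeeze(), labels.squeeze().float())
elif problem_type == "single_label_classification":
    loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
else:
    loss = BCEWithLogitsLoss()(logits, labels.float())
print(problem_type, loss.item())
```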
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TrOCR model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TrOCRConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TrOCRForCausalLM`]. It is used to instantiate a + TrOCR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the TrOCR + [microsoft/trocr-base-handwritten](https://huggingface.co/microsoft/trocr-base-handwritten) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the TrOCR model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`TrOCRForCausalLM`]. + d_model (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer. + decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the pooler. If string, `"gelu"`, `"relu"`, + `"silu"` and `"gelu_new"` are supported. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_embedding (`bool`, *optional*, defaults to `False`): + Whether or not to scale the word embeddings by sqrt(d_model).
+ use_learned_position_embeddings (`bool`, *optional*, defaults to `True`): + Whether or not to use learned position embeddings. If not, sinusoidal position embeddings will be used. + layernorm_embedding (`bool`, *optional*, defaults to `True`): + Whether or not to use a layernorm after the word + position embeddings. + + Example: + + ```python + >>> from transformers import TrOCRConfig, TrOCRForCausalLM + + >>> # Initializing a TrOCR-base style configuration + >>> configuration = TrOCRConfig() + + >>> # Initializing a model (with random weights) from the TrOCR-base style configuration + >>> model = TrOCRForCausalLM(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "trocr" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "decoder_attention_heads", + "hidden_size": "d_model", + "num_hidden_layers": "decoder_layers", + } + + def __init__( + self, + vocab_size=50265, + d_model=1024, + decoder_layers=12, + decoder_attention_heads=16, + decoder_ffn_dim=4096, + activation_function="gelu", + max_position_embeddings=512, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + decoder_start_token_id=2, + init_std=0.02, + decoder_layerdrop=0.0, + use_cache=True, + scale_embedding=False, + use_learned_position_embeddings=True, + layernorm_embedding=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.activation_function = activation_function + self.max_position_embeddings = max_position_embeddings + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.init_std = init_std + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.scale_embedding = scale_embedding + self.use_learned_position_embeddings = use_learned_position_embeddings + self.layernorm_embedding = layernorm_embedding + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +__all__ = ["TrOCRConfig"] diff --git a/docs/transformers/src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py b/docs/transformers/src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..a787932b76946cdf1d56dc0f1edacab364214deb --- /dev/null +++ b/docs/transformers/src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
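A quick sketch of the `attribute_map` aliasing defined above: the generic transformers attribute names resolve to the TrOCR-specific ones:

```python
from transformers import TrOCRConfig

config = TrOCRConfig(d_model=256, decoder_layers=6, decoder_attention_heads=8)
assert config.hidden_size == 256  # alias for d_model
assert config.num_hidden_layers == 6  # alias for decoder_layers
assert config.num_attention_heads == 8  # alias for decoder_attention_heads
```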
+"""Convert TrOCR checkpoints from the unilm repository.""" + +import argparse +from pathlib import Path + +import requests +import torch +from PIL import Image + +from transformers import ( + RobertaTokenizer, + TrOCRConfig, + TrOCRForCausalLM, + TrOCRProcessor, + VisionEncoderDecoderModel, + ViTConfig, + ViTImageProcessor, + ViTModel, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(encoder_config, decoder_config): + rename_keys = [] + for i in range(encoder_config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + (f"encoder.deit.blocks.{i}.norm1.weight", f"encoder.encoder.layer.{i}.layernorm_before.weight") + ) + rename_keys.append((f"encoder.deit.blocks.{i}.norm1.bias", f"encoder.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append( + (f"encoder.deit.blocks.{i}.attn.proj.weight", f"encoder.encoder.layer.{i}.attention.output.dense.weight") + ) + rename_keys.append( + (f"encoder.deit.blocks.{i}.attn.proj.bias", f"encoder.encoder.layer.{i}.attention.output.dense.bias") + ) + rename_keys.append( + (f"encoder.deit.blocks.{i}.norm2.weight", f"encoder.encoder.layer.{i}.layernorm_after.weight") + ) + rename_keys.append((f"encoder.deit.blocks.{i}.norm2.bias", f"encoder.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append( + (f"encoder.deit.blocks.{i}.mlp.fc1.weight", f"encoder.encoder.layer.{i}.intermediate.dense.weight") + ) + rename_keys.append( + (f"encoder.deit.blocks.{i}.mlp.fc1.bias", f"encoder.encoder.layer.{i}.intermediate.dense.bias") + ) + rename_keys.append( + (f"encoder.deit.blocks.{i}.mlp.fc2.weight", f"encoder.encoder.layer.{i}.output.dense.weight") + ) + rename_keys.append((f"encoder.deit.blocks.{i}.mlp.fc2.bias", f"encoder.encoder.layer.{i}.output.dense.bias")) + + # cls token, position embeddings and patch embeddings of encoder + rename_keys.extend( + [ + ("encoder.deit.cls_token", "encoder.embeddings.cls_token"), + ("encoder.deit.pos_embed", "encoder.embeddings.position_embeddings"), + ("encoder.deit.patch_embed.proj.weight", "encoder.embeddings.patch_embeddings.projection.weight"), + ("encoder.deit.patch_embed.proj.bias", "encoder.embeddings.patch_embeddings.projection.bias"), + ("encoder.deit.norm.weight", "encoder.layernorm.weight"), + ("encoder.deit.norm.bias", "encoder.layernorm.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, encoder_config): + for i in range(encoder_config.num_hidden_layers): + # queries, keys and values (only weights, no biases) + in_proj_weight = state_dict.pop(f"encoder.deit.blocks.{i}.attn.qkv.weight") + + state_dict[f"encoder.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : encoder_config.hidden_size, : + ] + state_dict[f"encoder.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + encoder_config.hidden_size : encoder_config.hidden_size * 2, : + ] + state_dict[f"encoder.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -encoder_config.hidden_size :, : + ] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of the IAM Handwriting Database +def prepare_img(checkpoint_url): + if "handwritten" in checkpoint_url: + url = 
"https://fki.tic.heia-fr.ch/static/img/a01-122-02-00.jpg" # industry + # url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-12.jpg" # have + # url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-10.jpg" # let + # url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" # + # url = "https://fki.tic.heia-fr.ch/static/img/a01-122.jpg" + elif "printed" in checkpoint_url or "stage1" in checkpoint_url: + url = "https://www.researchgate.net/profile/Dinh-Sang/publication/338099565/figure/fig8/AS:840413229350922@1577381536857/An-receipt-example-in-the-SROIE-2019-dataset_Q640.jpg" + im = Image.open(requests.get(url, stream=True).raw).convert("RGB") + return im + + +@torch.no_grad() +def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our VisionEncoderDecoderModel structure. + """ + # define encoder and decoder configs based on checkpoint_url + encoder_config = ViTConfig(image_size=384, qkv_bias=False) + decoder_config = TrOCRConfig() + + # size of the architecture + if "base" in checkpoint_url: + decoder_config.encoder_hidden_size = 768 + elif "large" in checkpoint_url: + # use ViT-large encoder + encoder_config.hidden_size = 1024 + encoder_config.intermediate_size = 4096 + encoder_config.num_hidden_layers = 24 + encoder_config.num_attention_heads = 16 + decoder_config.encoder_hidden_size = 1024 + else: + raise ValueError("Should either find 'base' or 'large' in checkpoint URL") + + # the large-printed + stage1 checkpoints uses sinusoidal position embeddings, no layernorm afterwards + if "large-printed" in checkpoint_url or "stage1" in checkpoint_url: + decoder_config.tie_word_embeddings = False + decoder_config.activation_function = "relu" + decoder_config.max_position_embeddings = 1024 + decoder_config.scale_embedding = True + decoder_config.use_learned_position_embeddings = False + decoder_config.layernorm_embedding = False + + # load HuggingFace model + encoder = ViTModel(encoder_config, add_pooling_layer=False) + decoder = TrOCRForCausalLM(decoder_config) + model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder) + model.eval() + + # load state_dict of original model, rename some keys + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"] + + rename_keys = create_rename_keys(encoder_config, decoder_config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, encoder_config) + + # remove parameters we don't need + del state_dict["encoder.deit.head.weight"] + del state_dict["encoder.deit.head.bias"] + del state_dict["decoder.version"] + + # add prefix to decoder keys + for key, val in state_dict.copy().items(): + val = state_dict.pop(key) + if key.startswith("decoder") and "output_projection" not in key: + state_dict["decoder.model." 
+ key] = val + else: + state_dict[key] = val + + # load state dict + model.load_state_dict(state_dict) + + # Check outputs on an image + image_processor = ViTImageProcessor(size=encoder_config.image_size) + tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-large") + processor = TrOCRProcessor(image_processor, tokenizer) + + pixel_values = processor(images=prepare_img(checkpoint_url), return_tensors="pt").pixel_values + + # verify logits + decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]]) + outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) + logits = outputs.logits + + expected_shape = torch.Size([1, 1, 50265]) + if "trocr-base-handwritten" in checkpoint_url: + expected_slice = torch.tensor( + [-1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764, 1.7560, 8.7358, -1.5311] + ) + elif "trocr-large-handwritten" in checkpoint_url: + expected_slice = torch.tensor( + [-2.6437, -1.3129, -2.2596, -5.3455, 6.3539, 1.7604, 5.4991, 1.4702, 5.6113, 2.0170] + ) + elif "trocr-base-printed" in checkpoint_url: + expected_slice = torch.tensor( + [-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210] + ) + elif "trocr-large-printed" in checkpoint_url: + expected_slice = torch.tensor( + [-6.0162, -7.0959, 4.4155, -5.1063, 7.0468, -3.1631, 2.6466, -0.3081, -0.8106, -1.7535] + ) + + if "stage1" not in checkpoint_url: + assert logits.shape == expected_shape, "Shape of logits not as expected" + assert torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-3), "First elements of logits not as expected" + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_url", + default="https://layoutlm.blob.core.windows.net/trocr/model_zoo/fairseq/trocr-base-handwritten.pt", + type=str, + help="URL to the original PyTorch checkpoint (.pth file).", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + args = parser.parse_args() + convert_tr_ocr_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) diff --git a/docs/transformers/src/transformers/models/trocr/modeling_trocr.py b/docs/transformers/src/transformers/models/trocr/modeling_trocr.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc0de5b6d8a88197db9997aa09f8461c7abb53e --- /dev/null +++ b/docs/transformers/src/transformers/models/trocr/modeling_trocr.py @@ -0,0 +1,958 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
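The fused-QKV split performed by `read_in_q_k_v` in the conversion script above, isolated as a sketch: the original checkpoint stores one `(3 * hidden, hidden)` projection per layer, while the ViT encoder expects separate query/key/value matrices:

```python
import torch

hidden_size = 8  # toy size; 768 (base) or 1024 (large) in the real checkpoints
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)

query = in_proj_weight[:hidden_size, :]
key = in_proj_weight[hidden_size : hidden_size * 2, :]
value = in_proj_weight[-hidden_size:, :]

# the three slices tile the fused matrix exactly
assert torch.equal(torch.cat([query, key, value], dim=0), in_proj_weight)
```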
+"""PyTorch TrOCR decoder model (based on RoBERTa).""" + +import copy +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, logging, replace_return_docstrings +from .configuration_trocr import TrOCRConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TrOCRConfig" +_CHECKPOINT_FOR_DOC = "microsoft/trocr-base-handwritten" + + +# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR +class TrOCRLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # TrOCR is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR +class TrOCRScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +class TrOCRSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx) + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". 
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx) + self.weights = self.weights.to(self._float_tensor) + + x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + return x + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +class TrOCRAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__( + self, + config, + embed_dim: int, + num_heads: int, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_cross_attention: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if not (self.head_dim * num_heads == self.embed_dim): + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
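`create_position_ids_from_input_ids` above, in miniature: padded positions keep `padding_idx`, and real tokens count up from `padding_idx + 1` (shown here with `past_key_values_length=0`):

```python
import torch

padding_idx = 1
input_ids = torch.tensor([[0, 5, 7, 1, 1]])  # the trailing 1s are padding

mask = input_ids.ne(padding_idx).int()
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
position_ids = incremental_indices.long() + padding_idx

print(position_ids)  # tensor([[2, 3, 4, 1, 1]])
```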
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class TrOCRDecoderLayer(nn.Module): + def __init__(self, config: TrOCRConfig): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = TrOCRAttention( + config, + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + if config.is_decoder: + self.encoder_attn = TrOCRAttention( + config, + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + kdim=config.cross_attention_hidden_size, + vdim=config.cross_attention_hidden_size, + dropout=config.attention_dropout, + is_decoder=True, + is_cross_attention=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size *(decoder_attention_heads,)*. 
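A shape walk-through of the attention math in `TrOCRAttention.forward` above, with masking, dropout, and head masking omitted (sizes are illustrative):

```python
import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8

query = torch.randn(bsz * num_heads, tgt_len, head_dim)  # already scaled by head_dim**-0.5
key = torch.randn(bsz * num_heads, src_len, head_dim)
value = torch.randn(bsz * num_heads, src_len, head_dim)

attn_weights = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)  # (bsz*heads, tgt, src)
attn_output = torch.bmm(attn_weights, value)  # (bsz*heads, tgt, head_dim)
attn_output = (
    attn_output.view(bsz, num_heads, tgt_len, head_dim)
    .transpose(1, 2)
    .reshape(bsz, tgt_len, num_heads * head_dim)
)
print(attn_output.shape)  # torch.Size([2, 5, 32])
```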
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class TrOCRPreTrainedModel(PreTrainedModel): + config_class = TrOCRConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TrOCRDecoderLayer"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +TROCR_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
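The per-layer cache layout described in the comments above, made concrete: self-attention key/value states sit at positions 0-1 of `present_key_value`, cross-attention states at positions 2-3:

```python
import torch

self_k = torch.randn(2, 4, 5, 8)  # (bsz, heads, tgt_len, head_dim)
self_v = torch.randn(2, 4, 5, 8)
cross_k = torch.randn(2, 4, 7, 8)  # keys/values over the encoder states
cross_v = torch.randn(2, 4, 7, 8)

present_key_value = (self_k, self_v) + (cross_k, cross_v)
assert present_key_value[0] is self_k  # consumed later as past_key_value[:2]
assert present_key_value[2] is cross_k  # consumed later as past_key_value[-2:]
```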
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TrOCRConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class TrOCRDecoder(TrOCRPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`] + + Args: + config: TrOCRConfig + """ + + def __init__(self, config: TrOCRConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = TrOCRScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale + ) + + if config.use_learned_position_embeddings: + self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) + else: + self.embed_positions = TrOCRSinusoidalPositionalEmbedding( + config.max_position_embeddings + self.padding_idx + 1, + config.hidden_size, + self.padding_idx, + ) + + if config.layernorm_embedding: + self.layernorm_embedding = nn.LayerNorm(config.hidden_size) + else: + self.layernorm_embedding = None + + self.layers = nn.ModuleList([TrOCRDecoderLayer(config) for _ in range(config.decoder_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. 
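A hedged sketch of the position-embedding switch in `TrOCRDecoder.__init__` above, using deliberately small hypothetical dimensions to keep instantiation cheap (`TrOCRDecoder` is imported from its module path since it is not a top-level export):

```python
from transformers import TrOCRConfig
from transformers.models.trocr.modeling_trocr import TrOCRDecoder

small = dict(vocab_size=100, d_model=32, decoder_layers=1, decoder_attention_heads=2, decoder_ffn_dim=64)

learned = TrOCRDecoder(TrOCRConfig(use_learned_position_embeddings=True, **small))
sinusoidal = TrOCRDecoder(TrOCRConfig(use_learned_position_embeddings=False, **small))

print(type(learned.embed_positions).__name__)  # TrOCRLearnedPositionalEmbedding
print(type(sinusoidal.embed_positions).__name__)  # TrOCRSinusoidalPositionalEmbedding
```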
Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input.shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.config.use_learned_position_embeddings: + embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length) + else: + embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + embed_pos + + if self.layernorm_embedding is not None: + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + input_shape = input.shape + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." 
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The TrOCR Model with a language modeling head.",
+    TROCR_START_DOCSTRING,
+)
+class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
+    """
+    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+    used in combination with the [`EncoderDecoderModel`] framework.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = TrOCRDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
+
+
+@add_start_docstrings(
+    "The TrOCR Decoder with a language modeling head.
Can be used as the decoder part of [`EncoderDecoderModel`] and"
+    " [`VisionEncoderDecoder`].",
+    TROCR_START_DOCSTRING,
+)
+class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["output_projection.weight"]
+
+    def __init__(self, config):
+        config = copy.deepcopy(config)
+        config.is_decoder = True
+        config.is_encoder_decoder = False
+        super().__init__(config)
+        self.model = TrOCRDecoderWrapper(config)
+
+        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.output_projection
+
+    def set_output_embeddings(self, new_embeddings):
+        self.output_projection = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import (
+        ...     TrOCRConfig,
+        ...     TrOCRProcessor,
+        ...     TrOCRForCausalLM,
+        ...     ViTConfig,
+        ...     ViTModel,
+        ...     VisionEncoderDecoderModel,
+        ... )
+        >>> import requests
+        >>> from PIL import Image
+
+        >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
+        >>> # init vision2text model with random weights
+        >>> encoder = ViTModel(ViTConfig())
+        >>> decoder = TrOCRForCausalLM(TrOCRConfig())
+        >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+        >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
+        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+
+        >>> # load image from the IAM dataset
+        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
+        >>> text = "industry, ' Mr. 
Brown commented icily. ' Let us have a" + + >>> # training + >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id + >>> model.config.pad_token_id = processor.tokenizer.pad_token_id + >>> model.config.vocab_size = model.config.decoder.vocab_size + + >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids + >>> outputs = model(pixel_values, labels=labels) + >>> loss = outputs.loss + >>> round(loss.item(), 2) + 5.30 + + >>> # inference + >>> generated_ids = model.generate(pixel_values) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> generated_text + 'industry, " Mr. Brown commented icily. " Let us have a' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.output_projection(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/trocr/processing_trocr.py b/docs/transformers/src/transformers/models/trocr/processing_trocr.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb5f281ecdfa81994e0e45cbcfb49e7f65df30d --- /dev/null +++ b/docs/transformers/src/transformers/models/trocr/processing_trocr.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for TrOCR. 
+""" + +import warnings +from contextlib import contextmanager +from typing import List, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class TrOCRProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + +class TrOCRProcessor(ProcessorMixin): + r""" + Constructs a TrOCR processor which wraps a vision image processor and a TrOCR tokenizer into a single processor. + + [`TrOCRProcessor`] offers all the functionalities of [`ViTImageProcessor`/`DeiTImageProcessor`] and + [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the [`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for + more information. + + Args: + image_processor ([`ViTImageProcessor`/`DeiTImageProcessor`], *optional*): + An instance of [`ViTImageProcessor`/`DeiTImageProcessor`]. The image processor is a required input. + tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`], *optional*): + An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self._in_target_context_manager = False + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[TrOCRProcessorKwargs], + ) -> BatchFeature: + """ + When used in normal mode, this method forwards all its arguments to AutoImageProcessor's + [`~AutoImageProcessor.__call__`] and returns its output. If used in the context + [`~TrOCRProcessor.as_target_processor`] this method forwards all its arguments to TrOCRTokenizer's + [`~TrOCRTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. 
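+
+        A minimal usage sketch (the checkpoint name matches the one used elsewhere in this file; any TrOCR
+        checkpoint would do, and the blank image is placeholder data):
+
+        ```python
+        >>> from PIL import Image
+        >>> from transformers import TrOCRProcessor
+
+        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+        >>> image = Image.new("RGB", (384, 384))
+        >>> batch = processor(images=image, text="a line of text", return_tensors="pt")
+        >>> sorted(batch.keys())  # the tokenized text is attached as `labels`
+        ['labels', 'pixel_values']
+        ```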
+ """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor(images, **kwargs) + + if images is None and text is None: + raise ValueError("You need to specify either an `images` or `text` input to process.") + + output_kwargs = self._merge_kwargs( + TrOCRProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + if text is not None: + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if text is None: + return inputs + elif images is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to TrOCRTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to TrOCRTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the + docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR. + """ + warnings.warn( + "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your " + "labels by using the argument `text` of the regular `__call__` method (either in the same call as " + "your images inputs, or in a separate call." + ) + self._in_target_context_manager = True + self.current_processor = self.tokenizer + yield + self.current_processor = self.image_processor + self._in_target_context_manager = False + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + @property + def feature_extractor(self): + warnings.warn( + "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["TrOCRProcessor"] diff --git a/docs/transformers/src/transformers/models/tvp/__init__.py b/docs/transformers/src/transformers/models/tvp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7388075201e03c4ee182310085686cef1c6b0256 --- /dev/null +++ b/docs/transformers/src/transformers/models/tvp/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_tvp import *
+    from .image_processing_tvp import *
+    from .modeling_tvp import *
+    from .processing_tvp import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/tvp/configuration_tvp.py b/docs/transformers/src/transformers/models/tvp/configuration_tvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..803acb220d081fe3140dafd8843491c61d396004
--- /dev/null
+++ b/docs/transformers/src/transformers/models/tvp/configuration_tvp.py
@@ -0,0 +1,201 @@
+# coding=utf-8
+# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TVP model configuration"""
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class TvpConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`TvpModel`]. It is used to instantiate a Tvp
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Tvp
+    [Intel/tvp-base](https://huggingface.co/Intel/tvp-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        backbone_config (`PretrainedConfig` or `dict`, *optional*):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        distance_loss_weight (`float`, *optional*, defaults to 1.0):
+            The weight of distance loss.
+        duration_loss_weight (`float`, *optional*, defaults to 0.1):
+            The weight of duration loss.
+ visual_prompter_type (`str`, *optional*, defaults to `"framepad"`): + Visual prompt type. The type of padding. Framepad means padding on each frame. Should be one of "framepad" + or "framedownpad" + visual_prompter_apply (`str`, *optional*, defaults to `"replace"`): + The way of applying visual prompt. Replace means use the value of prompt to change the original value in + visual inputs. Should be one of "replace", or "add", or "remove". + visual_prompt_size (`int`, *optional*, defaults to 96): + The size of visual prompt. + max_img_size (`int`, *optional*, defaults to 448): + The maximum size of frame. + num_frames (`int`, *optional*, defaults to 48): + The number of frames extracted from a video. + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Tvp text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`TvpModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + max_grid_col_position_embeddings (`int`, *optional*, defaults to 100): + The largest number of horizontal patches from a video frame. + max_grid_row_position_embeddings (`int`, *optional*, defaults to 100): + The largest number of vertical patches from a video frame. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability of hidden layers. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability of attention layers. + """ + + model_type = "tvp" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + distance_loss_weight=1.0, + duration_loss_weight=0.1, + visual_prompter_type="framepad", + visual_prompter_apply="replace", + visual_prompt_size=96, + max_img_size=448, + num_frames=48, + vocab_size=30522, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=512, + max_grid_col_position_embeddings=100, + max_grid_row_position_embeddings=100, + hidden_dropout_prob=0.1, + hidden_act="gelu", + layer_norm_eps=1e-12, + initializer_range=0.02, + attention_probs_dropout_prob=0.1, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. 
Initializing the config with the default `ResNet` backbone.")
+            backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
+        elif isinstance(backbone_config, dict):
+            backbone_model_type = backbone_config.get("model_type")
+            config_class = CONFIG_MAPPING[backbone_model_type]
+            backbone_config = config_class.from_dict(backbone_config)
+
+        verify_backbone_config_arguments(
+            use_timm_backbone=use_timm_backbone,
+            use_pretrained_backbone=use_pretrained_backbone,
+            backbone=backbone,
+            backbone_config=backbone_config,
+            backbone_kwargs=backbone_kwargs,
+        )
+
+        self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_kwargs = backbone_kwargs
+        self.distance_loss_weight = distance_loss_weight
+        self.duration_loss_weight = duration_loss_weight
+        self.visual_prompter_type = visual_prompter_type
+        self.visual_prompter_apply = visual_prompter_apply
+        self.visual_prompt_size = visual_prompt_size
+        self.max_img_size = max_img_size
+        self.num_frames = num_frames
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.max_grid_col_position_embeddings = max_grid_col_position_embeddings
+        self.max_grid_row_position_embeddings = max_grid_row_position_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+
+    @classmethod
+    def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
+        """Instantiate a [`TvpConfig`] (or a derived class) from a pre-trained backbone model configuration.
+
+        Args:
+            backbone_config ([`PretrainedConfig`]):
+                The backbone configuration.
+        Returns:
+            [`TvpConfig`]: An instance of a configuration object
+        """
+        return cls(backbone_config=backbone_config, **kwargs)
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+        if output["backbone_config"] is not None:
+            output["backbone_config"] = self.backbone_config.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+__all__ = ["TvpConfig"]
diff --git a/docs/transformers/src/transformers/models/tvp/image_processing_tvp.py b/docs/transformers/src/transformers/models/tvp/image_processing_tvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b120206217cb04ee43b847c75ac85f9aa2be2a
--- /dev/null
+++ b/docs/transformers/src/transformers/models/tvp/image_processing_tvp.py
@@ -0,0 +1,481 @@
+# coding=utf-8
+# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for TVP."""
+
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+    PaddingMode,
+    flip_channel_order,
+    pad,
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    get_image_size,
+    is_valid_image,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.vivit.image_processing_vivit.make_batched
+def make_batched(videos) -> List[List[ImageInput]]:
+    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+        return videos
+
+    elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+        return [videos]
+
+    elif is_valid_image(videos):
+        return [[videos]]
+
+    raise ValueError(f"Could not make batched video from {videos}")
+
+
+def get_resize_output_image_size(
+    input_image: np.ndarray,
+    max_size: int = 448,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> Tuple[int, int]:
+    height, width = get_image_size(input_image, input_data_format)
+    if height >= width:
+        ratio = width * 1.0 / height
+        new_height = max_size
+        new_width = new_height * ratio
+    else:
+        ratio = height * 1.0 / width
+        new_width = max_size
+        new_height = new_width * ratio
+    size = (int(new_height), int(new_width))
+
+    return size
+
+
+class TvpImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a Tvp image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 448}`):
+            Size of the output image after resizing. The longest edge of the image will be resized to
+            `size["longest_edge"]` while maintaining the aspect ratio of the original image. Can be overridden by
+            `size` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
+            parameter in the `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 448, "width": 448}`):
+            Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
+            `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. 
Can be overridden by the `do_rescale`
+            parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
+            in the `preprocess` method.
+        do_pad (`bool`, *optional*, defaults to `True`):
+            Whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` method.
+        pad_size (`Dict[str, int]`, *optional*, defaults to `{"height": 448, "width": 448}`):
+            Size of the image after applying the padding. Can be overridden by the `pad_size` parameter in the
+            `preprocess` method.
+        constant_values (`Union[float, Iterable[float]]`, *optional*, defaults to 0):
+            The fill value to use when padding the image.
+        pad_mode (`PaddingMode`, *optional*, defaults to `PaddingMode.CONSTANT`):
+            The padding mode to use when padding the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        do_flip_channel_order (`bool`, *optional*, defaults to `True`):
+            Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
+            parameter in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
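+
+    Example (a minimal sketch; the zero-filled frame below is placeholder data, and the output shape assumes the
+    default 448x448 crop and pad settings):
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import TvpImageProcessor
+
+    >>> processor = TvpImageProcessor()
+    >>> video = [np.zeros((3, 500, 400), dtype=np.uint8)]  # one channels-first frame
+    >>> outputs = processor(video, return_tensors="pt")
+    >>> list(outputs["pixel_values"].shape)  # [batch, frames, channels, height, width]
+    [1, 1, 3, 448, 448]
+    ```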
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_center_crop: bool = True,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_pad: bool = True,
+        pad_size: Dict[str, int] = None,
+        constant_values: Union[float, Iterable[float]] = 0,
+        pad_mode: PaddingMode = PaddingMode.CONSTANT,
+        do_normalize: bool = True,
+        do_flip_channel_order: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"longest_edge": 448}
+        crop_size = crop_size if crop_size is not None else {"height": 448, "width": 448}
+        pad_size = pad_size if pad_size is not None else {"height": 448, "width": 448}
+
+        self.do_resize = do_resize
+        self.size = size
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_pad = do_pad
+        self.pad_size = pad_size
+        self.constant_values = constant_values
+        self.pad_mode = pad_mode
+        self.do_normalize = do_normalize
+        self.do_flip_channel_order = do_flip_channel_order
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
+                have the size `(h, w)`. If `size` is of the form `{"longest_edge": s}`, the output image will have its
+                longest edge of length `s` while keeping the aspect ratio of the original image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        size = get_size_dict(size, default_to_square=False)
+        if "height" in size and "width" in size:
+            output_size = (size["height"], size["width"])
+        elif "longest_edge" in size:
+            output_size = get_resize_output_image_size(image, size["longest_edge"], input_data_format)
+        else:
+            raise ValueError(f"Size must have 'height' and 'width' or 'longest_edge' as keys. 
Got {size.keys()}") + + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def pad_image( + self, + image: np.ndarray, + pad_size: Dict[str, int] = None, + constant_values: Union[float, Iterable[float]] = 0, + pad_mode: PaddingMode = PaddingMode.CONSTANT, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Pad an image with zeros to the given size. + + Args: + image (`np.ndarray`): + Image to pad. + pad_size (`Dict[str, int]`) + Size of the output image with pad. + constant_values (`Union[float, Iterable[float]]`) + The fill value to use when padding the image. + pad_mode (`PaddingMode`) + The pad mode, default to PaddingMode.CONSTANT + data_format (`ChannelDimension` or `str`, *optional*) + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + height, width = get_image_size(image, channel_dim=input_data_format) + max_height = pad_size.get("height", height) + max_width = pad_size.get("width", width) + + pad_right, pad_bottom = max_width - width, max_height - height + if pad_right < 0 or pad_bottom < 0: + raise ValueError("The padding size must be greater than image size") + + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=pad_mode, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + + return padded_image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_pad: bool = True, + pad_size: Dict[str, int] = None, + constant_values: Union[float, Iterable[float]] = None, + pad_mode: PaddingMode = None, + do_normalize: Optional[bool] = None, + do_flip_channel_order: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """Preprocesses a single image.""" + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=pad_size, # here the pad() method simply requires the pad_size argument. + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. 
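+        # (PIL images and framework tensors are converted up front so that every transform below can
+        # operate on a plain `np.ndarray`.)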
+        image = to_numpy_array(image)
+
+        if do_resize:
+            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+        if do_center_crop:
+            image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
+
+        if do_rescale:
+            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+        if do_normalize:
+            image = self.normalize(
+                image=image.astype(np.float32), mean=image_mean, std=image_std, input_data_format=input_data_format
+            )
+
+        if do_pad:
+            image = self.pad_image(
+                image=image,
+                pad_size=pad_size,
+                constant_values=constant_values,
+                pad_mode=pad_mode,
+                input_data_format=input_data_format,
+            )
+
+        # the pretrained checkpoints assume images are BGR, not RGB
+        if do_flip_channel_order:
+            image = flip_channel_order(image=image, input_data_format=input_data_format)
+
+        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+        return image
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        videos: Union[ImageInput, List[ImageInput], List[List[ImageInput]]],
+        do_resize: Optional[bool] = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Dict[str, int] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_pad: Optional[bool] = None,
+        pad_size: Dict[str, int] = None,
+        constant_values: Union[float, Iterable[float]] = None,
+        pad_mode: PaddingMode = None,
+        do_normalize: Optional[bool] = None,
+        do_flip_channel_order: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            videos (`ImageInput` or `List[ImageInput]` or `List[List[ImageInput]]`):
+                Frames to preprocess.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after applying resize.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the image after applying the center crop.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values to the range [0, 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                Whether to pad the image.
+            pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
+                Size of the image after applying the padding.
+            constant_values (`Union[float, Iterable[float]]`, *optional*, defaults to `self.constant_values`):
+                The fill value to use when padding the image.
+ pad_mode (`PaddingMode`, *optional*, defaults to "PaddingMode.CONSTANT"): + Use what kind of mode in padding. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): + Whether to flip the channel order of the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the inferred channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_pad = do_pad if do_pad is not None else self.do_pad + pad_size = pad_size if pad_size is not None else self.pad_size + constant_values = constant_values if constant_values is not None else self.constant_values + pad_mode = pad_mode if pad_mode else self.pad_mode + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_flip_channel_order = ( + do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order + ) + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + + if not valid_images(videos): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+            )
+
+        videos = make_batched(videos)
+
+        videos = [
+            np.array(
+                [
+                    self._preprocess_image(
+                        image=img,
+                        do_resize=do_resize,
+                        size=size,
+                        resample=resample,
+                        do_center_crop=do_center_crop,
+                        crop_size=crop_size,
+                        do_rescale=do_rescale,
+                        rescale_factor=rescale_factor,
+                        do_pad=do_pad,
+                        pad_size=pad_size,
+                        constant_values=constant_values,
+                        pad_mode=pad_mode,
+                        do_normalize=do_normalize,
+                        do_flip_channel_order=do_flip_channel_order,
+                        image_mean=image_mean,
+                        image_std=image_std,
+                        data_format=data_format,
+                        input_data_format=input_data_format,
+                    )
+                    for img in video
+                ]
+            )
+            for video in videos
+        ]
+
+        data = {"pixel_values": videos}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["TvpImageProcessor"]
diff --git a/docs/transformers/src/transformers/models/tvp/modeling_tvp.py b/docs/transformers/src/transformers/models/tvp/modeling_tvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e93e6e9001b7d8c94aa4cb1aea2d8c15f6c5be7
--- /dev/null
+++ b/docs/transformers/src/transformers/models/tvp/modeling_tvp.py
@@ -0,0 +1,985 @@
+# coding=utf-8
+# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch TVP Model"""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import prune_linear_layer
+from ...utils import logging
+from ...utils.backbone_utils import load_backbone
+from .configuration_tvp import TvpConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TvpVideoGroundingOutput(ModelOutput):
+    """
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Temporal-Distance IoU loss for video grounding.
+        logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+            Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
+            input texts.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
+            the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+class TvpLoss(nn.Module):
+    """
+    This class computes the losses for `TvpForVideoGrounding`. It compares the predicted (start, end) time slot
+    with the ground-truth one and returns each of the requested criteria: temporal IoU, distance between the mid
+    points, and difference of durations.
+
+    Args:
+        losses (`List[str]`):
+            List of all the losses to be applied.
+    """
+
+    def __init__(self, losses):
+        super().__init__()
+        self.loss_map = {
+            "iou": self.loss_iou,
+            "distance": self.loss_distance,
+            "duration": self.loss_duration,
+        }
+        for loss in losses:
+            if loss not in self.loss_map:
+                raise ValueError(f"Loss {loss} not supported")
+
+        self.losses = losses
+
+    def loss_iou(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
+        """
+        Measure the intersection over union loss (1 - IoU) between candidate and ground-truth time slots.
+        """
+        inter = torch.min(candidates_end_time, end_time) - torch.max(candidates_start_time, start_time)
+        union = torch.max(candidates_end_time, end_time) - torch.min(candidates_start_time, start_time)
+        iou = 1 - inter.clamp(min=0) / union
+
+        return iou
+
+    def loss_distance(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
+        """
+        Measure the distance of mid points.
+        """
+        mid_candidates = torch.div(torch.add(candidates_start_time, candidates_end_time), 2.0)
+        mid_groundtruth = torch.div(torch.add(start_time, end_time), 2.0)
+        distance_diff = torch.div(
+            torch.max(mid_candidates, mid_groundtruth) - torch.min(mid_candidates, mid_groundtruth), duration
+        ).clamp(min=0.2)
+
+        return distance_diff
+
+    def loss_duration(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
+        """
+        Measure the difference of duration.
+        """
+        duration_candidates = torch.sub(candidates_end_time, candidates_start_time)
+        duration_groundtruth = torch.sub(end_time, start_time)
+        duration_diff = torch.square(torch.div(torch.sub(duration_candidates, duration_groundtruth), duration))
+        duration_diff = duration_diff.clamp(min=0.4)
+
+        return duration_diff
+
+    def forward(self, logits, labels):
+        """
+        This performs the loss computation.
+
+        Args:
+            logits (`torch.FloatTensor`):
+                The output logits of head module.
+            labels (`List[torch.FloatTensor]`):
+                List of tensors ([duration, start, end]), which contains the duration of the video and the start and
+                end time of the video segment corresponding to the text.
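+
+        A small illustrative sketch (all values below are made up; `TvpLoss` is an internal helper):
+
+        ```python
+        >>> import torch
+        >>> from transformers.models.tvp.modeling_tvp import TvpLoss
+
+        >>> criterion = TvpLoss(["iou", "distance", "duration"])
+        >>> logits = torch.tensor([[0.1, 0.6]])  # predicted (start, end) as fractions of the duration
+        >>> labels = [torch.tensor([10.0]), torch.tensor([2.0]), torch.tensor([5.0])]  # duration, start, end
+        >>> losses = criterion(logits, labels)
+        >>> sorted(losses.keys())
+        ['distance', 'duration', 'iou']
+        ```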
+ """ + duration, start_time, end_time = labels + candidates = torch.mul(logits, duration) + candidates_start_time, candidates_end_time = candidates[:, 0].float(), candidates[:, 1].float() + + losses_dict = {} + for loss in self.losses: + losses_dict.update( + {loss: self.loss_map[loss](start_time, end_time, candidates_start_time, candidates_end_time, duration)} + ) + + return losses_dict + + +class TvpVisionModel(nn.Module): + def __init__(self, config): + super().__init__() + self.backbone = load_backbone(config) + + if config.backbone_config is not None: + in_channels = config.backbone_config.hidden_sizes[-1] + elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"): + in_channels = self.backbone.config.hidden_sizes[-1] + elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"): + in_channels = self.backbone.config.hidden_size + else: + raise ValueError("Backbone config not found") + + self.grid_encoder_conv = nn.Conv2d( + in_channels, + config.hidden_size, + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False, + ) + + def forward(self, pixel_values): + batch_size, num_frames, num_channels, height, width = pixel_values.shape + # (batch_size * num_frames, num_channels, height, width) + pixel_values = pixel_values.view(batch_size * num_frames, num_channels, height, width) + grid_feat_outputs = self.backbone(pixel_values)["feature_maps"][0] + grid = self.grid_encoder_conv(grid_feat_outputs) + grid = nn.functional.max_pool2d(grid, kernel_size=2, stride=2) + grid = nn.functional.relu(grid, inplace=True) + new_channel, new_height, new_width = grid.shape[-3:] + # (batch_size, num_frames, num_channels, height, width) + grid = grid.view(batch_size, num_frames, new_channel, new_height, new_width) + # (batch_size, num_frames, height, width, num_channels) + grid = grid.permute(0, 1, 3, 4, 2) + return grid + + +class TvpVisualInputEmbedding(nn.Module): + """ + Takes input of both image and video (multi-frame) + """ + + def __init__(self, config): + super().__init__() + # sequence embedding + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.row_position_embeddings = nn.Embedding(config.max_grid_row_position_embeddings, config.hidden_size) + self.col_position_embeddings = nn.Embedding(config.max_grid_col_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(1, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings + self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings + + def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high + resolution images (high resolution videos). 
+ def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
+ """
+ Args:
+ grid: (batch_size, height, width, hidden_dim)
+ interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ Returns:
+ grid with the row and column position embeddings added: (batch_size, height, width, hidden_dim)
+ """
+ batch_size, height, width, hidden_dim = grid.shape
+
+ # add row-wise position embeddings
+ # (height, )
+ row_height = min(self.max_grid_row_position_embeddings, height)
+ row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
+ # (height, hidden_dim)
+ row_position_embeddings = self.row_position_embeddings(row_position_ids)
+ row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
+ # (batch_size, height, 1, hidden_dim)
+ row_position_embeddings = row_position_embeddings.view(*row_shape)
+
+ # add column-wise position embeddings
+ row_width = min(self.max_grid_col_position_embeddings, width)
+ col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
+ # (width, hidden_dim)
+ col_position_embeddings = self.col_position_embeddings(col_position_ids)
+ col_shape = (batch_size, 1, row_width, hidden_dim)
+ # (batch_size, 1, width, hidden_dim)
+ col_position_embeddings = col_position_embeddings.view(*col_shape)
+ # (batch_size, height, width, hidden_dim)
+ positional_embeddings = row_position_embeddings + col_position_embeddings
+
+ # This interpolation gets triggered ONLY when the input image is larger in any dimension than the original position embeddings
+ if interpolate_pos_encoding and (
+ height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
+ ):
+ grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
+ else:
+ grid = grid + positional_embeddings
+ return grid
+
+ def forward(self, grid, interpolate_pos_encoding: bool = False):
+ """
+ Args:
+ grid: Array of shape (batch_size, num_frames, height, width, num_channels).
+ It contains processed frames extracted from videos, and is generated by the Tvp image preprocessor.
+ Note that num_frames can be 1.
+ interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ + Returns: + embeddings: The embedding of grid with size (batch_size, height*width, num_channels) + + """ + batch_size, num_frames, height, width, num_channels = grid.shape + # temporal mean pooling, (batch_size, height, width, hidden_size) + grid = grid.mean(1) + grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding) + # image token sequence, (batch_size, height*width, num_channels) + visual_tokens = grid.view(batch_size, -1, num_channels) + visual_tokens_shape = visual_tokens.shape[:-1] + device = visual_tokens.device + + # image token type embeddings. + token_type_ids = torch.zeros(visual_tokens_shape, dtype=torch.long, device=device) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = visual_tokens + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class TvpTextInputEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class TvpAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.attn_dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.pruned_heads = set() 
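The `_reshape` helper defined below splits the hidden dimension into `num_attention_heads` subspaces of `attention_head_size` each. A toy-sized illustration (sizes are illustrative, not from a TVP checkpoint):

import torch

batch_size, seq_len, num_heads, head_size = 2, 5, 12, 64
hidden_states = torch.randn(batch_size, seq_len, num_heads * head_size)

# Same operation as TvpAttention._reshape:
# (batch, seq, hidden) -> (batch, heads, seq, head_size)
per_head = (
    hidden_states.view(batch_size, seq_len, num_heads, head_size)
    .transpose(1, 2)
    .contiguous()
)
print(per_head.shape)  # torch.Size([2, 12, 5, 64])
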
+ + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.num_attention_heads, self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.query = prune_linear_layer(self.query, index) + self.key = prune_linear_layer(self.key, index) + self.value = prune_linear_layer(self.value, index) + self.dense = prune_linear_layer(self.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.num_attention_heads = self.num_attention_heads - len(heads) + self.all_head_size = self.attention_head_size * self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def _reshape(self, tensor: torch.Tensor, sequence_length: int, batch_size: int): + return ( + tensor.view(batch_size, sequence_length, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions: Optional[bool] = None, + ): + batch_size, sequence_length = hidden_states.shape[:2] + mixed_query_layer = self.query(hidden_states) + + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self._reshape(mixed_query_layer, sequence_length, batch_size) + key_layer = self._reshape(mixed_key_layer, sequence_length, batch_size) + value_layer = self._reshape(mixed_value_layer, sequence_length, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.attn_dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + attn_output = torch.matmul(attention_probs, value_layer) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, sequence_length, self.all_head_size) + + attn_output = self.dense(attn_output) + attn_output = self.dropout(attn_output) + attn_output = self.layer_norm(attn_output + hidden_states) + # add attentions if we output them + outputs = (attn_output, attention_probs) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Tvp +class TvpIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class TvpOutputLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.layer_norm(hidden_states + input_tensor) + return hidden_states + + +class TvpEncodeLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = TvpAttention(config) + self.intermediate = TvpIntermediate(config) + self.output = TvpOutputLayer(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions: Optional[bool] = None, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs + return outputs + + +class TvpEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([TvpEncodeLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + return_dict = return_dict if return_dict is not None else self.config.return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + all_hidden_states = () + all_attentions = () + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + 
all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + (head_mask[i] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if output_hidden_states else None, + attentions=all_attentions if output_attentions else None, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Tvp +class TvpPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class TvpPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = TvpConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + +TVP_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`TvpConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TVP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. 
See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`TvpImageProcessor`]. See [`TvpImageProcessor.__call__`]
+ for details.
+
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained image pad prompter encodings and positional encodings.
+"""
+
+
+class TvpFrameDownPadPrompter(nn.Module):
+ """
+ Pad frames extracted from videos only at the bottom.
+ """
+
+ def __init__(self, config):
+ if config.visual_prompter_apply not in ("add", "replace", "remove"):
+ raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")
+
+ super().__init__()
+ self.visual_prompt_size = config.visual_prompt_size
+ self.frame_num = config.frame_num
+ self.max_img_size = config.max_img_size
+ self.visual_prompter_apply = config.visual_prompter_apply
+
+ self.pad_down = nn.Parameter(
+ torch.randn([1, config.frame_num, 3, config.visual_prompt_size, config.max_img_size])
+ )
+
+ def forward(self, pixel_values):
+ if self.visual_prompter_apply != "add":
+ visual_prompt_mask = torch.ones(
+ [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
+ )
+ visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
+ pixel_values *= visual_prompt_mask
+ if self.visual_prompter_apply != "remove":
+ prompt = torch.zeros(
+ [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
+ device=pixel_values.device,
+ )
+ start_point = self.max_img_size - self.visual_prompt_size
+ prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
+ pixel_values += prompt.to(pixel_values.dtype)
+ return pixel_values
+
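Shape-wise, `TvpFrameDownPadPrompter` learns a single `(1, frame_num, 3, visual_prompt_size, max_img_size)` strip and broadcasts it across the batch. A standalone sketch of the "replace" branch with toy sizes (96x96 frames and a 16-pixel strip; all values illustrative):

import torch

batch, frames, img, strip = 2, 4, 96, 16
pixel_values = torch.rand(batch, frames, 3, img, img)
pad_down = torch.randn(1, frames, 3, strip, img)  # the learnable parameter

# "replace": zero out the bottom strip, then write the learned prompt there.
mask = torch.ones(img, img)
mask[img - strip :, :] = 0.0
prompt = torch.zeros(batch, frames, 3, img, img)
prompt[:, :, :, img - strip :, :] = pad_down  # broadcasts over the batch
prompted = pixel_values * mask + prompt
print(prompted.shape)  # torch.Size([2, 4, 3, 96, 96])
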
+class TvpFramePadPrompter(nn.Module):
+ """
+ Pad frames extracted from videos on all four sides.
+ """
+
+ def __init__(self, config):
+ if config.visual_prompter_apply not in ("add", "replace", "remove"):
+ raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")
+
+ super().__init__()
+ self.num_frames = config.num_frames
+ self.max_img_size = config.max_img_size
+ self.visual_prompter_apply = config.visual_prompter_apply
+ self.base_size = config.max_img_size - config.visual_prompt_size * 2
+ self.pad_up = nn.Parameter(
+ torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
+ )
+ self.pad_down = nn.Parameter(
+ torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
+ )
+ self.pad_left = nn.Parameter(
+ torch.randn(
+ [
+ 1,
+ config.num_frames,
+ 3,
+ config.max_img_size - config.visual_prompt_size * 2,
+ config.visual_prompt_size,
+ ]
+ )
+ )
+ self.pad_right = nn.Parameter(
+ torch.randn(
+ [
+ 1,
+ config.num_frames,
+ 3,
+ config.max_img_size - config.visual_prompt_size * 2,
+ config.visual_prompt_size,
+ ]
+ )
+ )
+
+ def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows interpolating the pre-trained pad weights, to be able to use the model on collections of
+ high resolution images (high resolution videos).
+ """
+
+ # creates scale factors from the height and width of the original image with respect to config.max_img_size
+ h0, w0 = height / self.max_img_size, width / self.max_img_size
+
+ batch, num_frames, channels, prompt_height, prompt_width = prompt.shape
+
+ # reshaping the batch and num_frames dimensions into a single one (i.e. (b,frames,c,h,w) --> (b*frames,c,h,w)), to apply bicubic interpolation
+ prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
+ prompt = nn.functional.interpolate(
+ prompt,
+ scale_factor=(h0, w0),
+ mode="bicubic",
+ align_corners=False,
+ )
+ # reversing back to (batch,frames,channels,height,width), where height and width are the new interpolated height and width
+ prompt = prompt.reshape(batch, num_frames, channels, height, width)
+ return prompt
+
+ def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
+ height, width = (
+ (pixel_values.shape[-2], pixel_values.shape[-1])
+ if interpolate_pad_encoding
+ else (self.max_img_size, self.max_img_size)
+ )
+ if self.visual_prompter_apply not in ("add", "remove", "replace"):
+ raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
+ if self.visual_prompter_apply in ("replace", "remove"):
+ visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
+ pixel_values *= visual_prompt_mask
+ if self.visual_prompter_apply in ("replace", "add"):
+ base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
+
+ prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
+ prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
+ prompt = torch.cat(pixel_values.size(0) * [prompt])
+ if interpolate_pad_encoding:
+ prompt = self.interpolate_pad_encoding(prompt, height, width)
+ pixel_values = pixel_values + prompt.to(pixel_values.dtype)
+ return pixel_values
+
+
+TVP_PROMPTER_CLASSES_MAPPING = {
+ "framedownpad": TvpFrameDownPadPrompter,
+ "framepad": TvpFramePadPrompter,
+}
+
+
+@add_start_docstrings(
+ "The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.",
+ TVP_START_DOCSTRING,
+)
+class TvpModel(TvpPreTrainedModel):
+ def
__init__(self, config): + super().__init__(config) + self.config = config + self.vision_model = TvpVisionModel(config) + self.embeddings = TvpTextInputEmbeddings(config) + self.visual_embeddings = TvpVisualInputEmbedding(config) + self.encoder = TvpEncoder(config) + self.pooler = TvpPooler(config) + self.text_prompt = nn.Parameter(torch.randn([1, 10, config.hidden_size])) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.visual_prompter_type not in TVP_PROMPTER_CLASSES_MAPPING: + raise ValueError("`visual_prompter_type` must be in (framedownpad, framepad)") + self.visual_prompter = TVP_PROMPTER_CLASSES_MAPPING[config.visual_prompter_type](config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TVP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=TvpConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ): + r""" + Returns: + + Examples: + ```python + >>> import torch + >>> from transformers import AutoConfig, AutoTokenizer, TvpModel + + >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp") + + >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp") + + >>> pixel_values = torch.rand(1, 1, 3, 448, 448) + >>> text_inputs = tokenizer("This is an example input", return_tensors="pt") + >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask) + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + # Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features. + pixel_values = self.vision_model( + self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding) + ) + # (batch_size, sequence_length, hidden_size) + text_embedding_output = self.embeddings(input_ids=input_ids) + # (batch_size, visual_sequence_length, hidden_size) + visual_embedding_output = self.visual_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) + + if attention_mask is not None: + # (batch_size, visual_sequence_length) + visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2]) + pt_mask = torch.ones(attention_mask.shape[0], 10).to( + device=attention_mask.device, dtype=attention_mask.dtype + ) + attention_mask = torch.cat([pt_mask, attention_mask, visual_attention_mask], dim=-1) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+
+ attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size()).to(input_ids.device)
+ text_prompt = self.text_prompt.expand(text_embedding_output.shape[0], -1, -1)
+ # (batch_size, sequence_length + visual_sequence_length, hidden_size)
+ embedding_output = torch.cat([text_prompt, text_embedding_output, visual_embedding_output], dim=1)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=attention_mask,
+ head_mask=self.get_head_mask(head_mask, self.config.num_hidden_layers),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
+ pooled_output = self.pooler(last_hidden_state)
+ last_hidden_state = self.dropout(last_hidden_state)
+ pooled_output = self.dropout(pooled_output)
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class TvpVideoGroundingHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_0 = nn.Linear(config.hidden_size, config.hidden_size * 2)
+ self.layer_1 = nn.Linear(config.hidden_size * 2, 2)
+ self.activation_0 = nn.ReLU()
+ self.activation_1 = nn.Sigmoid()
+
+ def forward(self, pooler_output):
+ logits = self.activation_0(self.layer_0(pooler_output))
+ logits = self.activation_1(self.layer_1(logits))
+ return logits
+
+
+@add_start_docstrings(
+ """
+ Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
+ """,
+ TVP_START_DOCSTRING,
+)
+class TvpForVideoGrounding(TvpPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.model = TvpModel(config)
+ self.video_grounding_head = TvpVideoGroundingHead(config)
+
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(TVP_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TvpVideoGroundingOutput, config_class=TvpConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ labels: Tuple[torch.Tensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ):
+ r"""
+ labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
+ The labels contain the duration, start time, and end time of the video corresponding to the text.
+ Returns:
+
+ Examples:
+ ```python
+ >>> import torch
+ >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding
+
+ >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")
+
+ >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
+ >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
+ >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+ outputs = self.model(
+ input_ids,
+ pixel_values,
+ attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+ pooler_output = outputs[1]
+ logits = self.video_grounding_head(pooler_output)
+
+ loss = None
+ if labels is not None:
+ criterion = TvpLoss(["iou", "distance", "duration"])
+ criterion.to(self.device)
+ loss_dict = criterion(logits, labels)
+ loss = (
+ loss_dict["iou"]
+ + self.config.distance_loss_weight * loss_dict["distance"]
+ + self.config.duration_loss_weight * loss_dict["duration"]
+ )
+ if not return_dict:
+ outputs = (logits,) + outputs[2:]
+ if loss is not None:
+ outputs = (loss,) + outputs
+ return outputs
+
+ return TvpVideoGroundingOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["TvpModel", "TvpPreTrainedModel", "TvpForVideoGrounding"]
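Putting the modeling pieces together, a hedged end-to-end sketch of video grounding with the tiny test checkpoint used in the docstring examples above (the query string and the 30-second duration are made up; random pixel values stand in for real frames):

import torch
from transformers import AutoTokenizer, TvpForVideoGrounding

model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")
tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

# One "video" consisting of a single random 448x448 frame, plus a text query.
pixel_values = torch.rand(1, 1, 3, 448, 448)
text_inputs = tokenizer("a person opens the door", return_tensors="pt")

with torch.no_grad():
    outputs = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)

# The head emits normalized (start, end) fractions in [0, 1]; scaling by the
# clip duration mirrors TvpProcessor.post_process_video_grounding below.
duration = 30.0
start = round(outputs.logits.tolist()[0][0] * duration, 1)
end = round(outputs.logits.tolist()[0][1] * duration, 1)
print(f"predicted span: {start}s - {end}s")
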
diff --git a/docs/transformers/src/transformers/models/tvp/processing_tvp.py b/docs/transformers/src/transformers/models/tvp/processing_tvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..76baae91346cfe21ddc506d2bd61347cdc0d82c3
--- /dev/null
+++ b/docs/transformers/src/transformers/models/tvp/processing_tvp.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for TVP.
+"""
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+
+
+class TvpProcessor(ProcessorMixin):
+ r"""
+ Constructs a TVP processor which wraps a TVP image processor and a Bert tokenizer into a single processor.
+
+ [`TvpProcessor`] offers all the functionalities of [`TvpImageProcessor`] and [`BertTokenizerFast`]. See the
+ [`~TvpProcessor.__call__`] and [`~TvpProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`TvpImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`BertTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "TvpImageProcessor"
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+ if image_processor is None:
+ raise ValueError("You need to specify an `image_processor`.")
+ if tokenizer is None:
+ raise ValueError("You need to specify a `tokenizer`.")
+
+ super().__init__(image_processor, tokenizer)
+
+ def __call__(self, text=None, videos=None, return_tensors=None, **kwargs):
+ """
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `videos` and `kwargs` arguments to
+ TvpImageProcessor's [`~TvpImageProcessor.__call__`] if `videos` is not `None`. Please refer to the docstring of
+ the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`,
+ `List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
+ of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
+ each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
+ channels.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `videos` is not `None`.
+ """
+
+ max_text_length = kwargs.pop("max_text_length", None)
+
+ if text is None and videos is None:
+ raise ValueError("You have to specify either text or videos. Both cannot be none.")
+
+ encoding = {}
+ if text is not None:
+ textual_input = self.tokenizer.batch_encode_plus(
+ text,
+ truncation=True,
+ padding="max_length",
+ max_length=max_text_length,
+ pad_to_max_length=True,
+ return_tensors=return_tensors,
+ return_token_type_ids=False,
+ **kwargs,
+ )
+ encoding.update(textual_input)
+
+ if videos is not None:
+ image_features = self.image_processor(videos, return_tensors=return_tensors, **kwargs)
+ encoding.update(image_features)
+
+ return BatchEncoding(data=encoding, tensor_type=return_tensors)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_video_grounding(self, logits, video_durations): + """ + Compute the time of the video. + + Args: + logits (`torch.Tensor`): + The logits output of TvpForVideoGrounding. + video_durations (`float`): + The video's duration. + + Returns: + start (`float`): + The start time of the video. + end (`float`): + The end time of the video. + """ + start, end = ( + round(logits.tolist()[0][0] * video_durations, 1), + round(logits.tolist()[0][1] * video_durations, 1), + ) + + return start, end + + @property + # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["TvpProcessor"] diff --git a/docs/transformers/src/transformers/models/udop/__init__.py b/docs/transformers/src/transformers/models/udop/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4c36f6363f5f9704526b0cb302952abe6f87dc --- /dev/null +++ b/docs/transformers/src/transformers/models/udop/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_udop import * + from .modeling_udop import * + from .processing_udop import * + from .tokenization_udop import * + from .tokenization_udop_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/udop/configuration_udop.py b/docs/transformers/src/transformers/models/udop/configuration_udop.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed28c78c8fb4c90c172deed21ba1a8048168e5e --- /dev/null +++ b/docs/transformers/src/transformers/models/udop/configuration_udop.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""UDOP model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class UdopConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`UdopForConditionalGeneration`]. It is used to + instantiate a UDOP model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the UDOP + [microsoft/udop-large](https://huggingface.co/microsoft/udop-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 33201): + Vocabulary size of the UDOP model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`UdopForConditionalGeneration`]. + d_model (`int`, *optional*, defaults to 1024): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 4096): + Size of the intermediate feed forward layer in each `UdopBlock`. + num_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder and decoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder and decoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + relative_bias_args (`List[dict]`, *optional*, defaults to `[{'type': '1d'}, {'type': 'horizontal'}, {'type': 'vertical'}]`): + A list of dictionaries containing the arguments for the relative bias layers. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. Udopv1.1 uses the + `"gated-gelu"` feed forward projection. Original Udop uses `"relu"`. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model should behave as an encoder/decoder or not. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 1): + The id of the end-of-sequence token in the vocabulary. + max_2d_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum absolute position embeddings for relative position encoding. + image_size (`int`, *optional*, defaults to 224): + The size of the input images. + patch_size (`int`, *optional*, defaults to 16): + The patch size used by the vision encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of channels in the input images. + """ + + model_type = "udop" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=33201, + d_model=1024, + d_kv=64, + d_ff=4096, + num_layers=24, + num_decoder_layers=None, + num_heads=16, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + relative_bias_args=[{"type": "1d"}, {"type": "horizontal"}, {"type": "vertical"}], + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + max_2d_position_embeddings=1024, + image_size=224, + patch_size=16, + num_channels=3, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + # UDOP attributes + self.max_2d_position_embeddings = max_2d_position_embeddings + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + if not isinstance(relative_bias_args, list): + raise TypeError("`relative_bias_args` should be a list of dictionaries.") + self.relative_bias_args = relative_bias_args + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer." + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +__all__ = ["UdopConfig"] diff --git a/docs/transformers/src/transformers/models/udop/convert_udop_to_hf.py b/docs/transformers/src/transformers/models/udop/convert_udop_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba0de55df78f45808131d6eb538a96006fddf4d --- /dev/null +++ b/docs/transformers/src/transformers/models/udop/convert_udop_to_hf.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert UDOP checkpoints from the original repository. URL: https://github.com/microsoft/i-Code/tree/main/i-Code-Doc""" + +import argparse + +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms as T + +from transformers import ( + LayoutLMv3ImageProcessor, + UdopConfig, + UdopForConditionalGeneration, + UdopProcessor, + UdopTokenizer, +) +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def original_transform(image, image_size=224): + transform = T.Compose( + [ + T.Resize([image_size, image_size]), + T.ToTensor(), + T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + ] + ) + + image = transform(image) + return image + + +def get_image(): + filepath = hf_hub_download( + repo_id="hf-internal-testing/fixtures_docvqa", filename="document_2.png", repo_type="dataset" + ) + image = Image.open(filepath).convert("RGB") + + return image + + +def prepare_dummy_inputs(tokenizer, image_processor): + prompt = "Question answering. What is the name of the company?" + prompt = "Question answering. In which year is the report made?" 
+ prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + + image = get_image() + # words, boxes = apply_tesseract(image, lang=None) + # fmt: off + words = ['7', 'ITC', 'Limited', 'REPORT', 'AND', 'ACCOUNTS', '2013', 'ITC’s', 'Brands:', 'An', 'Asset', 'for', 'the', 'Nation', 'The', 'consumer', 'needs', 'and', 'aspirations', 'they', 'fulfil,', 'the', 'benefit', 'they', 'generate', 'for', 'millions', 'across', 'ITC’s', 'value', 'chains,', 'the', 'future-ready', 'capabilities', 'that', 'support', 'them,', 'and', 'the', 'value', 'that', 'they', 'create', 'for', 'the', 'country,', 'have', 'made', 'ITC’s', 'brands', 'national', 'assets,', 'adding', 'to', 'India’s', 'competitiveness.', 'It', 'is', 'ITC’s', 'aspiration', 'to', 'be', 'the', 'No', '1', 'FMCG', 'player', 'in', 'the', 'country,', 'driven', 'by', 'its', 'new', 'FMCG', 'businesses.', 'A', 'recent', 'Nielsen', 'report', 'has', 'highlighted', 'that', "ITC's", 'new', 'FMCG', 'businesses', 'are', 'the', 'fastest', 'growing', 'among', 'the', 'top', 'consumer', 'goods', 'companies', 'operating', 'in', 'India.', 'ITC', 'takes', 'justifiable', 'pride', 'that,', 'along', 'with', 'generating', 'economic', 'value,', 'these', 'celebrated', 'Indian', 'brands', 'also', 'drive', 'the', 'creation', 'of', 'larger', 'societal', 'capital', 'through', 'the', 'virtuous', 'cycle', 'of', 'sustainable', 'and', 'inclusive', 'growth.', 'DI', 'WILLS', '*', ';', 'LOVE', 'DELIGHTFULLY', 'SOFT', 'SKIN?', 'aia', 'Ans', 'Source:', 'https://www.industrydocuments.ucsf.edu/docs/snbx0223'] + boxes = [[0, 45, 67, 80], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [175, 137, 306, 158], [318, 137, 363, 158], [374, 137, 472, 158], [483, 136, 529, 158], [540, 137, 593, 158], [608, 137, 717, 158], [73, 194, 100, 203], [106, 196, 177, 203], [183, 194, 227, 203], [233, 194, 259, 203], [265, 194, 344, 205], [74, 211, 104, 222], [109, 210, 141, 221], [147, 211, 169, 220], [175, 210, 223, 220], [229, 211, 259, 222], [265, 211, 329, 222], [334, 210, 352, 220], [74, 227, 127, 236], [133, 229, 180, 236], [187, 227, 221, 236], [226, 227, 264, 236], [270, 227, 320, 237], [327, 227, 349, 236], [74, 243, 161, 254], [166, 243, 249, 254], [254, 243, 281, 252], [286, 244, 342, 254], [74, 260, 112, 270], [119, 260, 145, 269], [151, 260, 174, 269], [179, 260, 217, 269], [222, 260, 249, 269], [254, 260, 285, 271], [290, 260, 335, 269], [340, 259, 359, 269], [74, 276, 95, 284], [101, 276, 156, 287], [164, 276, 198, 284], [203, 276, 244, 284], [251, 275, 285, 284], [291, 276, 340, 284], [74, 292, 129, 301], [135, 292, 185, 302], [192, 292, 242, 303], [248, 292, 261, 301], [267, 292, 312, 301], [74, 308, 195, 319], [75, 335, 82, 344], [88, 335, 98, 344], [105, 335, 138, 344], [144, 335, 214, 346], [220, 336, 233, 344], [239, 335, 256, 344], [262, 335, 283, 344], [290, 335, 309, 344], [316, 335, 320, 344], [74, 351, 119, 360], [126, 352, 170, 362], [176, 352, 186, 360], [192, 352, 214, 360], [220, 352, 276, 362], [282, 352, 326, 360], [333, 352, 349, 362], [74, 368, 89, 377], [95, 370, 124, 377], [129, 367, 175, 377], [181, 368, 266, 377], [272, 368, 283, 376], [289, 368, 333, 377], [74, 384, 126, 393], [134, 385, 175, 395], [181, 384, 206, 393], [212, 384, 292, 395], [298, 384, 325, 393], [330, 384, 366, 393], [74, 403, 103, 409], [109, 400, 154, 409], [161, 401, 241, 409], [247, 403, 269, 409], [275, 401, 296, 409], [302, 400, 349, 409], [74, 417, 131, 428], [137, 419, 186, 428], [192, 417, 214, 
426], [219, 417, 242, 428], [248, 419, 319, 426], [74, 433, 119, 444], [125, 433, 204, 444], [210, 433, 278, 444], [285, 433, 295, 441], [302, 433, 340, 442], [75, 449, 98, 458], [104, 449, 142, 458], [146, 449, 215, 460], [221, 449, 258, 460], [263, 449, 293, 459], [300, 449, 339, 460], [74, 466, 101, 474], [108, 466, 185, 476], [191, 466, 261, 474], [267, 466, 309, 476], [315, 466, 354, 474], [74, 482, 151, 491], [158, 482, 201, 491], [208, 482, 258, 491], [263, 482, 292, 491], [298, 482, 333, 491], [338, 482, 360, 491], [74, 498, 131, 507], [137, 498, 150, 507], [156, 498, 197, 509], [202, 498, 257, 507], [263, 498, 310, 509], [74, 515, 128, 525], [134, 515, 156, 523], [161, 515, 218, 523], [223, 515, 261, 525], [267, 514, 280, 523], [74, 531, 156, 540], [162, 531, 188, 540], [195, 531, 257, 540], [263, 531, 315, 542], [871, 199, 878, 202], [883, 199, 908, 202], [894, 251, 904, 257], [841, 268, 841, 270], [784, 373, 811, 378], [816, 373, 896, 378], [784, 381, 811, 387], [815, 381, 847, 387], [645, 908, 670, 915], [692, 908, 712, 915], [220, 984, 285, 993], [293, 983, 779, 996]] + # fmt: on + text_list = [] + bbox_list = [] + for text, box in zip(words, boxes): + if text == "": + continue + sub_tokens = tokenizer.tokenize(text) + for sub_token in sub_tokens: + text_list.append(sub_token) + bbox_list.append(box) + + input_ids = tokenizer.convert_tokens_to_ids(text_list) + + input_ids = prompt_ids + input_ids + bbox = [[0, 0, 0, 0]] * len(prompt_ids) + bbox_list + + pixel_values = image_processor(image, return_tensors="pt").pixel_values + original_pixel_values = original_transform(image, image_size=image_processor.size["height"]).unsqueeze(0) + # verify pixel values + assert torch.allclose(original_pixel_values, pixel_values) + print("Pixel values are ok!") + + return torch.tensor(input_ids).unsqueeze(0), torch.tensor(bbox).unsqueeze(0).float(), pixel_values + + +def convert_udop_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False): + # model_name to checkpoint_path + name_to_checkpoint_path = { + "udop-large": "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-224/pytorch_model.bin", + "udop-large-512": "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-512/pytorch_model.bin", + "udop-large-512-300k": "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-512-300k-steps/pytorch_model.bin", + } + + # load original state dict + checkpoint_path = name_to_checkpoint_path[model_name] + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + print("Checkpoint path:", checkpoint_path) + + # create HF model + image_size = 512 if "512" in model_name else 224 + config = UdopConfig(decoder_start_token_id=0, image_size=image_size) + model = UdopForConditionalGeneration(config) + model.eval() + + # rename keys + state_dict = {k.replace("cell2dembedding", "cell_2d_embedding"): v for k, v in state_dict.items()} + + # load weights + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + assert missing_keys == ["encoder.embed_patches.proj.weight", "encoder.embed_patches.proj.bias"] + assert unexpected_keys == ["pos_embed"] + + # Add extra_ids to the special token list + # NOTE special tokens have a unique order + # see https://github.com/huggingface/transformers/issues/29591 for details + # fmt: off + additional_special_tokens = ['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 
'', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 
'', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''] + # fmt: on + + tokenizer = UdopTokenizer.from_pretrained( + "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-512", + legacy=True, + additional_special_tokens=additional_special_tokens, + ) + size = {"height": image_size, "width": image_size} + image_processor = LayoutLMv3ImageProcessor( + image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD, size=size + ) + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # prepare dummy inputs + input_ids, bbox, image = prepare_dummy_inputs(tokenizer, image_processor) + prompt = "Question answering. In which year is the report made?" + encoding = processor(images=get_image(), text=prompt, return_tensors="pt") + + input_ids = encoding.input_ids + try: + EXPECTED_INPUT_IDS = torch.tensor([[11860, 18243, 5, 86, 84, 215, 19, 8, 934, 263, 58, 1, 489, 27, 3838, 7363, 4083, 14536, 3430, 5686, 5911, 17161, 134, 2038, 27, 3838, 22, 7, 4688, 7, 10, 389, 18202, 21, 8, 11046, 37, 3733, 523, 11, 38, 2388, 1628, 3, 13133, 23334, 6, 8, 1656, 79, 3806, 21, 4040, 640, 27, 3838, 22, 7, 701, 16534, 6, 8, 3, 76, 2693, 18, 23015, 5644, 24, 380, 3, 6015, 6, 11, 8, 701, 24, 79, 482, 21, 3, 88, 684, 6, 43, 263, 27, 3838, 22, 7, 3635, 1157, 4089, 6, 2651, 12, 1547, 22, 7, 3265, 655, 5, 19, 27, 3838, 22, 7, 38, 2388, 257, 12, 36, 8, 465, 209, 13409, 12150, 1959, 16, 8, 684, 6, 6737, 57, 165, 126, 13409, 12150, 1623, 5, 71, 1100, 30298, 934, 65, 12566, 24, 27, 3838, 31, 7, 126, 13409, 12150, 1623, 33, 8, 10391, 1710, 859, 8, 420, 3733, 4968, 688, 2699, 16, 1547, 5, 27, 3838, 1217, 131, 99, 23, 179, 6064, 24, 6, 590, 28, 3, 11600, 1456, 701, 6, 175, 9443, 2557, 3635, 92, 1262, 8, 3409, 13, 2186, 3, 27908, 1784, 190, 8, 3, 5771, 17, 13281, 4005, 13, 5086, 11, 13066, 1170, 5, 10826, 16309, 134, 3, 2, 276, 26, 3, 55, 391, 13570, 5, 10315, 309, 3577, 19114, 371, 4254, 5121, 5055, 6245, 3, 10047, 3162, 58, 3, 9, 61, 1713, 2703, 476, 667, 25158, 301, 6058, 6038, 476, 3765, 9149, 10, 4893, 1303, 1986, 5, 13580, 7, 8224, 28244, 7, 5, 76, 75, 7, 89, 5, 15, 1259, 87, 7171, 7, 87, 7, 29, 115, 226, 4305, 2773, 1]]) # fmt: skip + torch.testing.assert_close(EXPECTED_INPUT_IDS, input_ids) + bbox = encoding.bbox.float() + pixel_values = encoding.pixel_values + except Exception: + print("Input_ids don't match, preparing dummy inputs") + input_ids, bbox, pixel_values = prepare_dummy_inputs(tokenizer, image_processor) + + # Verify 
single forward pass + print("Testing single forward pass..") + with torch.no_grad(): + decoder_input_ids = torch.tensor([[101]]) + outputs = model(input_ids=input_ids, bbox=bbox, pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) + print("Shape of logits:", outputs.logits.shape) + print("First values of logits:", outputs.logits[0, :3, :3]) + + # tensor([[-18.5262, 1.5087, -15.7051]]) on linux + # tensor([[-19.4976, 0.8515, -17.1873]]) on mac + try: + assert torch.allclose(outputs.logits[0, :3, :3], torch.tensor([[-18.5262, 1.5087, -15.7051]]), atol=1e-4) + print("Looks ok!") + except Exception: + print("logits don't match let's try to generate") + + # Verify autoregressive decoding + print("Testing generation...") + model_kwargs = {"bbox": bbox, "pixel_values": pixel_values} + outputs = model.generate(input_ids=input_ids, **model_kwargs, max_new_tokens=20) + + print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + + # autoregressive decoding with original input data + print("Testing generation with original inputs...") + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="input_ids_udop.pt", repo_type="dataset") + input_ids = torch.load(filepath, weights_only=True) + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="bbox_udop.pt", repo_type="dataset") + bbox = torch.load(filepath, weights_only=True) + pixel_values_filename = "pixel_values_udop_512.pt" if "512" in model_name else "pixel_values_udop_224.pt" + filepath = hf_hub_download(repo_id="nielsr/test-image", filename=pixel_values_filename, repo_type="dataset") + pixel_values = torch.load(filepath, weights_only=True) + + print("Decoded input ids:", tokenizer.decode(input_ids[0], skip_special_tokens=True)) + print("Bbox shape:", bbox.shape) + + model_kwargs = {"bbox": bbox, "pixel_values": pixel_values} + outputs = model.generate(input_ids=input_ids, **model_kwargs, max_new_tokens=20) + generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + print("Generated:", generated_text) + + if pytorch_dump_folder_path is not None: + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"microsoft/{model_name}") + processor.push_to_hub(f"microsoft/{model_name}") + # BIG note here: to save the fast tokenizer files in the repo on the hub, you need to do the following: + # see https://discuss.huggingface.co/t/convert-slow-xlmrobertatokenizer-to-fast-one/20876 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="udop-large", + type=str, + choices=["udop-large", "udop-large-512", "udop-large-512-300k"], + help=("Name of the UDOP model you'd like to convert."), + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_udop_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/udop/modeling_udop.py b/docs/transformers/src/transformers/models/udop/modeling_udop.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd7f984237381ca8ac2d9fa263339fca4a8affb --- /dev/null +++ b/docs/transformers/src/transformers/models/udop/modeling_udop.py @@ -0,0 +1,2171 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UDOP model.""" + +import collections +import logging +import math +import random +from abc import ABC, abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from transformers import UdopConfig +from transformers.modeling_outputs import ( + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + replace_return_docstrings, +) + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.getLogger(__name__) + + +_CONFIG_FOR_DOC = "UdopConfig" + + +UDOP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Args: + config ([`UdopConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +UDOP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. UDOP is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. 
Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. + [What are input IDs?](../glossary#input-ids) + + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + + bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + + Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] + token. See `pixel_values` for `patch_sequence_length`. + + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size, + config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height / + config.patch_size) * (width / config.patch_size))`. + + visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*): + Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model. + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting + token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last + `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare + `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If + `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of + `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. +""" + + +UDOP_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training). 
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + + Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] + token. See `pixel_values` for `patch_sequence_length`. + + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size, + config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height / + config.patch_size) * (width / config.patch_size))`. + + visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*): + Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@dataclass +class BaseModelOutputWithAttentionMask(ModelOutput): + """ + Class for the model's outputs that may also contain a past key/values (to speed up sequential decoding). Includes + an additional attention mask. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only + the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the + self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) + that can be used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and + `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. 
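+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            The attention mask that corresponds to `last_hidden_state`. In UDOP the encoder extends the text attention
+            mask with entries for the appended image patches, so the extended mask is surfaced here alongside the
+            hidden states.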
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + attention_mask: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def get_visual_bbox(image_size=224, patch_size=16): + image_feature_pool_shape = [image_size // patch_size, image_size // patch_size] + visual_bbox_x = torch.arange(0, 1.0 * (image_feature_pool_shape[1] + 1), 1.0) + visual_bbox_x /= image_feature_pool_shape[1] + + visual_bbox_y = torch.arange(0, 1.0 * (image_feature_pool_shape[0] + 1), 1.0) + visual_bbox_y /= image_feature_pool_shape[0] + + visual_bbox_input = torch.stack( + [ + visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1), + visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1), + visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1), + visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1), + ], + dim=-1, + ) + + visual_bbox_input = visual_bbox_input.view(-1, 4) + + return visual_bbox_input + + +def pad_sequence(seq, target_len, pad_value=0): + if isinstance(seq, torch.Tensor): + n = seq.shape[0] + else: + n = len(seq) + seq = torch.tensor(seq) + m = target_len - n + if m > 0: + ret = torch.stack([pad_value] * m).to(seq) + seq = torch.cat([seq, ret], dim=0) + return seq[:target_len] + + +def combine_image_text_embeddings( + image_embeddings, + inputs_embeds, + bbox, + visual_bbox, + attention_mask=None, + num_patches=14, + max_len=0, + image_size=224, + patch_size=16, +): + """ + Combine the image and text embeddings for the input to the encoder/decoder of UDOP. + + First, the image embeddings are created by checking for each visual patch if it is inside the bounding box of a + token. If it is, the visual patch is combined with the token embedding. Then, the visual bounding boxes are combined + with the text bounding boxes. Finally, the visual bounding boxes are combined with the text attention mask. 
+ """ + + sequence_length = num_patches + ocr_points_x = torch.clip( + torch.floor((bbox[:, :, 0] + bbox[:, :, 2]) / 2.0 * sequence_length).long(), 0, sequence_length - 1 + ) + ocr_points_y = ( + torch.clip(torch.floor((bbox[:, :, 1] + bbox[:, :, 3]) / 2.0 * sequence_length).long(), 0, sequence_length - 1) + * sequence_length + ) + ocr_points = ocr_points_x + ocr_points_y + # make sure bounding boxes are of type float to calculate means + bbox = bbox.to(torch.float64) + target_seg = (bbox.mean(-1) == 0.0) | (bbox.mean(-1) == 1.0) + repeated_vision_embeds = torch.gather( + image_embeddings, 1, ocr_points.unsqueeze(-1).repeat(1, 1, image_embeddings.size(-1)) + ) + repeated_vision_embeds[target_seg] = 0.0 + inputs_embeds += repeated_vision_embeds + + patch_inds = torch.full_like(image_embeddings[:, :, 0], True).bool() + ind = torch.cat( + [ + torch.arange(len(ocr_points))[:, None].repeat(1, ocr_points.size(-1))[:, :, None].to(ocr_points), + ocr_points[:, :, None], + ], + dim=-1, + ) + ind = ind.flatten(0, 1) + rows, cols = zip(*ind) + patch_inds[rows, cols] = False + + input_vision_patches = [image_embeddings[i][patch_inds[i]] for i in range(len(patch_inds))] + + if visual_bbox is None: + visual_bbox = get_visual_bbox(image_size=image_size, patch_size=patch_size) + visual_bbox = visual_bbox.unsqueeze(0).repeat(image_embeddings.size(0), 1, 1) + visual_bbox = visual_bbox.to(image_embeddings.device) + + visual_bbox = [visual_bbox[i][patch_inds[i]] for i in range(len(patch_inds))] + if attention_mask is not None: + visual_attention_mask = [torch.tensor([1] * len(item)).to(attention_mask) for item in visual_bbox] + + if max_len == 0: + max_len = image_embeddings.size(1) + else: + max_len = max_len - inputs_embeds.size(1) + inputs_vision_patches = torch.stack( + [pad_sequence(item, max_len, torch.zeros_like(image_embeddings[0, 0])) for item in input_vision_patches] + ) + visual_bbox = torch.stack([pad_sequence(item, max_len, torch.zeros_like(bbox[0, 0])) for item in visual_bbox]) + if attention_mask is not None: + visual_attention_mask = torch.stack( + [pad_sequence(item, max_len, torch.zeros_like(attention_mask[0, 0])) for item in visual_attention_mask] + ) + + inputs_embeds = torch.cat([inputs_embeds, inputs_vision_patches], 1) + bbox = torch.cat([bbox, visual_bbox], 1) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask, visual_attention_mask], 1) + return inputs_embeds, bbox, attention_mask + + +class UdopPatchEmbeddings(nn.Module): + """2D Image to Patch Embeddings""" + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.proj = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." 
+ ) + embeddings = self.proj(pixel_values) + embeddings = embeddings.flatten(2).transpose(1, 2) + return embeddings + + +class UdopPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. Based on `T5PreTrainedModel`. + """ + + config_class = UdopConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False + _keep_in_fp32_modules = ["wo"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, UdopLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=factor).to( + module.weight.dtype + ) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, RelativePositionBiasBase): + factor = self.config.initializer_factor + d_model = self.config.d_model + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + elif isinstance(module, UdopModel): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, UdopForConditionalGeneration): + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, UdopDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, UdopDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, UdopAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = 
self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right with ProphetNet->Udop + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In Udop it is usually set to the" + " pad_token_id. See Udop docs for more information" + ) + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Udop +class UdopLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the Udop style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # Udop uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Udop +class UdopDenseActDense(nn.Module): + def __init__(self, config: UdopConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Udop +class UdopDenseGatedActDense(nn.Module): + def __init__(self, config: UdopConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
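+        # Concretely: if the activations arrive in float16/bfloat16 while `self.wo.weight` was kept in float32 via
+        # `_keep_in_fp32_modules`, the cast below promotes the activations so the matmul dtypes agree; int8
+        # (quantized) weights are deliberately excluded from this cast.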
+        # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Udop +class UdopLayerFF(nn.Module): + def __init__(self, config: UdopConfig): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = UdopDenseGatedActDense(config) + else: + self.DenseReluDense = UdopDenseActDense(config) + + self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop +class UdopAttention(nn.Module): + def __init__( + self, + config: UdopConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
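+        For example, with the defaults `bidirectional=True` and `num_buckets=32`, the buckets split into 16 per
+        direction: a relative position of -1 maps to bucket 1, +1 maps to bucket 17, and distances beyond
+        `max_distance` saturate at the last bucket of their half.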
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
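+        Returns a tuple `(attn_output, past_key_value, position_bias)`, with the attention weights appended when
+        `output_attentions=True`.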
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * 
layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop +class UdopLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = UdopAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop +class UdopLayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop +class UdopBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + UdopLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) + if self.is_decoder: + self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx)) + + self.layer.append(UdopLayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + 
encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + cache_position=None, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states, past_key_value = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, past_key_value = cross_attention_outputs[:2] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (past_key_value,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class UdopCellEmbeddings(nn.Module): + def __init__(self, max_2d_position_embeddings=501, hidden_size=1024): + super(UdopCellEmbeddings, self).__init__() + self.max_2d_position_embeddings = max_2d_position_embeddings + + self.x_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size) + self.y_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size) + + def forward(self, bbox): + bbox = torch.clip(bbox, 0.0, 1.0) + bbox = (bbox * (self.max_2d_position_embeddings - 1)).long() + left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) + upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) + right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) + 
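+        # x0/x1 index the shared x-position table and y0/y1 the shared y-position table; the lookup below
+        # completes the four corner coordinates, which are then summed into a single layout embedding per token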
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) + + embeddings = ( + left_position_embeddings + + upper_position_embeddings + + right_position_embeddings + + lower_position_embeddings + ) + + return embeddings + + +# get function for bucket computation +# protected member access seems to be lesser evil than copy paste whole function +get_relative_position_bucket = UdopAttention._relative_position_bucket +AUGMENTATION_RANGE = (0.80, 1.25) + + +class RelativePositionBiasBase(nn.Module, ABC): + """ + Base class of relative biases. + + Args: + num_heads (`int`): + Number of attention heads in the model, it will create embeddings of size `num_heads`, which will be added to the scores of each token pair. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + Pair token metric (distance in the sequence, distance in pixels etc.) will be bucketed, parameter is defining number of such + buckets. + bidirectional (`bool`, *optional*, defaults to `True`): + Whether the distance should be bidirectional for a pair of tokens. If `False`, then distance(tok1, tok2) == distance(tok2, tok1). + scaling_factor (`int`, *optional*, defaults to 1): + Defining factor which will be used to scale relative distance. + max_distance (`int`, *optional*, defaults to 128): + All distances above this value will end up in the one/same bucket. + augmentation (`bool`, *optional*, defaults to `False`): + Whether to multiply relative distances by a random scalar. + expand (`bool`, *optional*, defaults to `False`): + Whether to expand an existing pretrained model with subsequent additions of prefix_bucket. + """ + + def __init__( + self, + num_heads=None, + relative_attention_num_buckets=32, + bidirectional=True, + scaling_factor=1, + max_distance=128, + level="tokens", + augmentation=False, + prefix_bucket=False, + expand=False, + ): + super(RelativePositionBiasBase, self).__init__() + self.prefix_bucket = prefix_bucket + self.augmentation = augmentation + self.level = level + self.max_distance = max_distance + self.scaling_factor = scaling_factor + self.bidirectional = bidirectional + self.num_heads = num_heads + self.expand = expand + self.relative_attention_num_buckets = relative_attention_num_buckets + extra_head = 2 if prefix_bucket and not self.expand else 0 + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets + extra_head, self.num_heads) + + @abstractmethod + def prepare_input( + self, + attention_mask: Optional[Tensor] = None, + bbox: Optional[Dict[str, Any]] = None, + ) -> Tensor: + pass + + def get_bucket(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + relative_position = self.prepare_input(attention_mask, bbox) + rp_bucket: Tensor = get_relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.max_distance, + ) + return rp_bucket + + def get_relative_position(self, positions): + context_position = positions[:, :, None] + memory_position = positions[:, None, :] + relative_position = memory_position - context_position + if self.augmentation and self.training: + relative_position *= random.uniform(*AUGMENTATION_RANGE) + relative_position *= self.scaling_factor + + return relative_position.to(torch.long) + + def forward(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + # re-using pretrained model with subsequent addition of prefix_bucket + if self.expand and 
self.prefix_bucket: + new_bias = nn.Embedding(self.relative_attention_num_buckets + 2, self.num_heads) + new_bias.weight.data[: self.relative_attention_num_buckets] = self.relative_attention_bias.weight.data + new_bias.weight.data[self.relative_attention_num_buckets :] = 0.1 + self.relative_attention_bias = new_bias + self.expand = False + + rp_bucket = self.get_bucket(attention_mask, bbox) + + if self.prefix_bucket: + if rp_bucket.size(0) == 1 and attention_mask.size(0) > 1: + rp_bucket = rp_bucket.repeat(attention_mask.size(0), 1, 1) + # based on assumption that prefix bboxes are negative + is_prefix = bbox[:, :, 1] < 0 + num_prefix = is_prefix.sum(-1) + for idx, num_prefix_row in enumerate(num_prefix.cpu().numpy()): + rp_bucket[idx, :num_prefix_row, num_prefix_row:] = self.relative_attention_num_buckets + rp_bucket[idx, num_prefix_row:, :num_prefix_row] = self.relative_attention_num_buckets + 1 + + values: Tensor = self.relative_attention_bias(rp_bucket) + if values.dim() != 4: + raise ValueError("Wrong dimension of values tensor") + values = values.permute([0, 3, 1, 2]) + + return values + + +class RelativePositionBias1D(RelativePositionBiasBase): + def __init__(self, scaling_factor=1, max_distance=128, **kwargs): + """ + Reimplementation of T5 relative position bias. The distance between two tokens is their distance in the sequence. + Parameters are the same as in the base class. + """ + super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs) + + def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + if self.scaling_factor != 1: + raise ValueError("No need to scale 1d features") + relative_position = self.get_relative_position( + torch.arange(attention_mask.size(1), dtype=torch.long, device=attention_mask.device)[None, :] + ) + + return relative_position + + +class RelativePositionBiasHorizontal(RelativePositionBiasBase): + def __init__(self, scaling_factor=100, max_distance=100, **kwargs): + """ + Represents the horizontal distance between two tokens in the bucket embeddings. Parameters are the same as in + the base class. + """ + super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs) + + def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + if not self.scaling_factor > 1.0: + raise ValueError("Need to scale the values of bboxes, as they are in the small (0, 1) range") + if bbox is None: + raise ValueError("Bbox is required for horizontal relative position bias") + # get x positions of the center of the bbox + horizontal_position: Tensor = bbox[:, :, [0, 2]].mean(dim=-1) + + return self.get_relative_position(horizontal_position) + + +class RelativePositionBiasVertical(RelativePositionBiasBase): + def __init__(self, scaling_factor=100, max_distance=100, **kwargs): + """ + Represents the vertical distance between two tokens in the bucket embeddings. 
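+        The vertical coordinate of a token is taken as the midpoint `(y0 + y1) / 2` of its bounding box and
+        multiplied by `scaling_factor` before bucketing.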
+        """
+        super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
+
+    def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor:
+        if not self.scaling_factor > 1.0:
+            raise ValueError("Need to scale the values of bboxes, as they are in a small (0, 1) range")
+        if bbox is None:
+            raise ValueError("Bbox is required for vertical relative position bias")
+        # get the y position of the vertical center of each bbox
+        vertical_position: Tensor = bbox[:, :, [1, 3]].mean(dim=-1)
+
+        return self.get_relative_position(vertical_position)
+
+
+class RelativePositionBiasAggregated(nn.Module):
+    def __init__(self, modules: Sequence[RelativePositionBiasBase]):
+        """
+        Class which sums up various computed biases.
+
+        Args:
+            modules (Sequence[RelativePositionBiasBase]):
+                List of relative bias modules.
+        """
+        super().__init__()
+        self.biases = nn.ModuleList(modules)
+
+    def forward(
+        self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None
+    ) -> Union[float, Tensor]:
+        output = 0.0
+        for bias in self.biases:  # type: ignore
+            output = bias(attention_mask, bbox) + output
+
+        return output
+
+
+BIAS_CLASSES = {
+    "1d": RelativePositionBias1D,
+    "horizontal": RelativePositionBiasHorizontal,
+    "vertical": RelativePositionBiasVertical,
+}
+
+
+def create_relative_bias(config: UdopConfig) -> Sequence[RelativePositionBiasBase]:
+    """
+    Creates an empty list, or one or multiple relative bias modules, depending on `config.relative_bias_args`.
+
+    Args:
+        config ([`UdopConfig`]):
+            Model configuration.
+
+    Returns:
+        Sequence with the created bias modules.
+    """
+    bias_list = []
+    if hasattr(config, "relative_bias_args"):
+        for bias_kwargs_org in config.relative_bias_args:
+            bias_kwargs = deepcopy(bias_kwargs_org)
+            bias_type = bias_kwargs.pop("type")
+            model_num_heads = config.num_heads if hasattr(config, "num_heads") else config.num_attention_heads
+            if "num_heads" in bias_kwargs:
+                if bias_kwargs["num_heads"] != model_num_heads:
+                    raise ValueError("Number of heads must match the number of heads in the model")
+            else:
+                bias_kwargs["num_heads"] = model_num_heads
+            bias_list.append(BIAS_CLASSES[bias_type](**bias_kwargs))  # type: ignore
+
+    return bias_list
+
+
+class UdopStack(UdopPreTrainedModel):
+    """
+    This class is based on `T5Stack`, but modified to take into account the image modality as well as 2D position
+    embeddings.
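+
+    The relative attention biases added on top of the `T5Stack` behaviour are built by `create_relative_bias` from
+    `config.relative_bias_args`. With the default UDOP configuration this resolves to one 1D (sequence) bias and two
+    2D (layout) biases (an illustrative sketch, assuming the default configuration):
+
+    ```python
+    >>> from transformers import UdopConfig
+
+    >>> config = UdopConfig()
+    >>> [bias_kwargs["type"] for bias_kwargs in config.relative_bias_args]
+    ['1d', 'horizontal', 'vertical']
+    ```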
+ """ + + def __init__(self, config, embed_tokens=None, embed_patches=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.embed_patches = embed_patches + self.is_decoder = config.is_decoder + self._max_length = config.max_length + self.num_layers = config.num_layers + + self.block = nn.ModuleList( + [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)] + ) + self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + self.dropout = nn.Dropout(config.dropout_rate) + + if not self.is_decoder: + self.cell_2d_embedding = UdopCellEmbeddings(config.max_2d_position_embeddings, config.hidden_size) + + # get weights from encoder position bias + self.relative_bias = self._get_relative_bias(config) + + def _tie_weights(self): + for bias in self.relative_bias.biases: + if isinstance(bias, RelativePositionBias1D): + self._tie_or_clone_weights( + bias.relative_attention_bias, self.block[0].layer[0].SelfAttention.relative_attention_bias + ) + + @staticmethod + def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated: + relative_bias_list = create_relative_bias(config) + return RelativePositionBiasAggregated(relative_bias_list) + + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + bbox=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + pixel_values=None, + visual_bbox=None, + image_embeddings=None, + position_bias=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # input embeddings processing + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None and torch.numel(input_ids) > 0: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is None and input_ids is not None and torch.numel(input_ids) == 0: + input_ids = torch.full((4, 1024), self.config.pad_token_id, device=input_ids.device, dtype=input_ids.dtype) + attention_mask = torch.zeros((4, 1024), device=input_ids.device, dtype=input_ids.dtype) + bbox = torch.zeros((4, 1024, 4), device=input_ids.device, dtype=input_ids.dtype) + input_shape = input_ids.size() + position_bias = torch.zeros_like(self.get_extended_attention_mask(attention_mask, input_shape)) + # encoder_attention_mask = attention_mask + logger.warning("Empty batch") + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}inputs or 
{err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to intialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + if pixel_values is not None: + image_embeddings = self.embed_patches(pixel_values) + + if image_embeddings is not None: + # combine visual and OCR text embeddings + num_patches = self.config.image_size // self.config.patch_size + inputs_embeds, bbox, attention_mask = combine_image_text_embeddings( + image_embeddings, + inputs_embeds, + bbox, + visual_bbox, + attention_mask, + num_patches, + 0, + self.config.image_size, + self.config.patch_size, + ) + input_shape = inputs_embeds.size()[:-1] + + if not self.is_decoder and bbox is not None: + inputs_embeds += self.cell_2d_embedding(bbox) + + batch_size, seq_length = input_shape + + if use_cache is True: + assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + + if self.is_decoder and encoder_attention_mask is not None: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and 
+        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+
+        if self.is_decoder:  # differs from `T5Stack`: the encoder computes an aggregated relative bias below
+            position_bias = None
+        else:
+            position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox)
+            position_bias = position_bias + causal_mask
+        encoder_decoder_position_bias = None
+
+        hidden_states = inputs_embeds
+
+        hidden_states = self.dropout(hidden_states)
+
+        for i, layer_module in enumerate(self.block):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_bias=position_bias,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_extended_attention_mask,
+                encoder_decoder_position_bias=encoder_decoder_position_bias,
+                layer_head_mask=head_mask[i],
+                past_key_value=past_key_values,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+            # layer_outputs is a tuple with:
+            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
+            # (cross-attention position bias), (cross-attention weights)
+            if use_cache is False:  # MP fixes
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+            hidden_states, next_decoder_cache = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer stores them
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[3],)  # We keep only self-attention weights for now
+                if self.is_decoder:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        # Add last layer
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if return_self_attention_cache:
+            next_cache = past_key_values.self_attention_cache
+        if return_legacy_cache:
+            next_cache = past_key_values.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    attention_mask,
+                    next_cache,
+                    all_hidden_states,
+                    all_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+
+        return BaseModelOutputWithAttentionMask(
+            last_hidden_state=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: Union[torch.Tensor, "BlockMask"],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool = False,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and (attention_mask == 0.0).any():
+                return attention_mask
+            return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            return attention_mask
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+        # in order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will
+        # fail to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        sequence_length = input_tensor.shape[1]
+        if using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
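+
+        Example (a minimal toy invocation; the values are chosen only to show the output shape):
+
+        ```python
+        >>> import torch
+
+        >>> mask = UdopStack._prepare_4d_causal_attention_mask_with_cache_position(
+        ...     attention_mask=torch.ones(1, 3),
+        ...     sequence_length=3,
+        ...     target_length=3,
+        ...     dtype=torch.float32,
+        ...     device="cpu",
+        ...     cache_position=torch.arange(3),
+        ...     batch_size=1,
+        ... )
+        >>> mask.shape
+        torch.Size([1, 1, 3, 3])
+        ```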
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@add_start_docstrings( + "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", + UDOP_START_DOCSTRING, +) +class UdopModel(UdopPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "encoder.embed_patches.proj.weight", + "encoder.embed_patches.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight", + "decoder.relative_bias.biases.0.relative_attention_bias.weight", + ] + + def __init__(self, config): + super(UdopModel, self).__init__(config) + + # text and image embeddings + self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.patch_embed = UdopPatchEmbeddings(config) + + encoder_config = deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + + decoder_config = deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = UdopStack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(UDOP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + bbox: Dict[str, Any] = None, + pixel_values: Optional[Tensor] = None, + visual_bbox: Dict[str, Any] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + encoder_outputs: Optional[Tensor] = None, + past_key_values: Optional[Tensor] = None, + head_mask: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + decoder_head_mask: Optional[Tensor] = None, + cross_attn_head_mask: Optional[Tensor] = None, + use_cache=True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + 
return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[Tensor, ...]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, AutoModel
+        >>> from datasets import load_dataset
+        >>> import torch
+
+        >>> # load model and processor
+        >>> # in this case, we already have performed OCR ourselves
+        >>> # so we initialize the processor with `apply_ocr=False`
+        >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
+        >>> model = AutoModel.from_pretrained("microsoft/udop-large")
+
+        >>> # load an example image, along with the words and coordinates
+        >>> # which were extracted using an OCR engine
+        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
+        >>> example = dataset[0]
+        >>> image = example["image"]
+        >>> words = example["tokens"]
+        >>> boxes = example["bboxes"]
+        >>> inputs = processor(image, words, boxes=boxes, return_tensors="pt")
+
+        >>> decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
+
+        >>> # forward pass
+        >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 1, 1024]
+        ```"""
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                bbox=bbox,
+                pixel_values=pixel_values,
+                visual_bbox=visual_bbox,
+                inputs_embeds=inputs_embeds,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        hidden_states = encoder_outputs[0]
+        encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        if not return_dict:
+            # we filter out the attention mask
+            decoder_outputs = tuple(value for idx, value in enumerate(decoder_outputs) if idx != 1)
+            encoder_outputs = tuple(value for idx, value in enumerate(encoder_outputs) if idx != 1)
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """The UDOP encoder-decoder Transformer with a language modeling head on top, enabling it to generate text given
+    document images and an optional prompt.
+
+    This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.""",
+    UDOP_START_DOCSTRING,
+)
+class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = [
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
+        "encoder.embed_patches.proj.weight",
+        "encoder.embed_patches.proj.bias",
+        "encoder.relative_bias.biases.0.relative_attention_bias.weight",
+        "decoder.relative_bias.biases.0.relative_attention_bias.weight",
+        "lm_head.weight",
+    ]
+
+    def __init__(self, config):
+        super(UdopForConditionalGeneration, self).__init__(config)
+
+        # text and image embeddings
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+        self.patch_embed = UdopPatchEmbeddings(config)
+
+        encoder_config = deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed)
+
+        decoder_config = deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = UdopStack(decoder_config, self.shared)
+
+        # The weights of the language modeling head are shared with those of the encoder and decoder
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+        self.decoder.set_input_embeddings(new_embeddings)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(UDOP_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        bbox: Dict[str, Any] = None,
+        pixel_values: Optional[Tensor] = None,
+        visual_bbox: Dict[str, Any] = None,
+        decoder_input_ids: Optional[Tensor] = None,
+        decoder_attention_mask: Optional[Tensor] = None,
+        inputs_embeds: Optional[Tensor] = None,
+        encoder_outputs: Optional[Tensor] = None,
+        past_key_values: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        decoder_inputs_embeds: Optional[Tensor] = None,
+        decoder_head_mask: Optional[Tensor] = None,
+        cross_attn_head_mask: Optional[Tensor] = None,
+        use_cache=True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[Tensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[Tensor, ...]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
+            1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+            config.vocab_size]`.
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, UdopForConditionalGeneration + >>> from datasets import load_dataset + + >>> # load model and processor + >>> # in this case, we already have performed OCR ourselves + >>> # so we initialize the processor with `apply_ocr=False` + >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False) + >>> model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large") + + >>> # load an example image, along with the words and coordinates + >>> # which were extracted using an OCR engine + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True) + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + + >>> # one can use the various task prefixes (prompts) used during pre-training + >>> # e.g. the task prefix for DocVQA is "Question answering. " + >>> question = "Question answering. What is the date on the form?" + >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt") + + >>> # autoregressive generation + >>> predicted_ids = model.generate(**encoding) + >>> print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]) + 9/30/92 + ```""" + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if decoder_input_ids is None and labels is not None: + decoder_input_ids = self._shift_right(labels) + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + bbox=bbox, + visual_bbox=visual_bbox, + pixel_values=pixel_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, 
+ decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare UDOP Model transformer outputting encoder's raw hidden-states without any specific head on top.", + UDOP_START_DOCSTRING, +) +class UdopEncoderModel(UdopPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "encoder.embed_patches.proj.weight", + "encoder.embed_patches.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight", + ] + + def __init__(self, config: UdopConfig): + super().__init__(config) + + # text and image embeddings + self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.patch_embed = UdopPatchEmbeddings(config) + + encoder_config = deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
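+        For example, `heads_to_prune = {0: [0, 1]}` would prune heads 0 and 1 of the first encoder layer (illustrative
+        indices).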
+
+        Args:
+            heads_to_prune (`Dict[int, List[int]]`):
+                Dictionary of `{layer_num: [list of heads to prune in this layer]}`. See the base class
+                [`PreTrainedModel`].
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(UDOP_ENCODER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithAttentionMask, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[Tensor] = None,
+        bbox: Dict[str, Any] = None,
+        attention_mask: Optional[Tensor] = None,
+        pixel_values: Optional[Tensor] = None,
+        visual_bbox: Dict[str, Any] = None,
+        head_mask: Optional[Tensor] = None,
+        inputs_embeds: Optional[Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithAttentionMask]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, UdopEncoderModel
+        >>> from datasets import load_dataset
+
+        >>> # load model and processor
+        >>> # in this case, we already have performed OCR ourselves
+        >>> # so we initialize the processor with `apply_ocr=False`
+        >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
+        >>> model = UdopEncoderModel.from_pretrained("microsoft/udop-large")
+
+        >>> # load an example image, along with the words and coordinates
+        >>> # which were extracted using an OCR engine
+        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
+        >>> example = dataset[0]
+        >>> image = example["image"]
+        >>> words = example["tokens"]
+        >>> boxes = example["bboxes"]
+        >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
+
+        >>> outputs = model(**encoding)
+        >>> last_hidden_states = outputs.last_hidden_state
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            bbox=bbox,
+            visual_bbox=visual_bbox,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        return encoder_outputs
+
+
+__all__ = ["UdopForConditionalGeneration", "UdopPreTrainedModel", "UdopModel", "UdopEncoderModel"]
diff --git a/docs/transformers/src/transformers/models/udop/processing_udop.py b/docs/transformers/src/transformers/models/udop/processing_udop.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce8847a5c3ac6b8b7062588dcc90dff560159f01
--- /dev/null
+++ b/docs/transformers/src/transformers/models/udop/processing_udop.py
@@ -0,0 +1,216 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for UDOP. +""" + +from typing import List, Optional, Union + +from transformers import logging + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +logger = logging.get_logger(__name__) + + +class UdopTextKwargs(TextKwargs, total=False): + word_labels: Optional[Union[List[int], List[List[int]]]] + boxes: Union[List[List[int]], List[List[List[int]]]] + + +class UdopProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: UdopTextKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "truncation": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } + + +class UdopProcessor(ProcessorMixin): + r""" + Constructs a UDOP processor which combines a LayoutLMv3 image processor and a UDOP tokenizer into a single processor. + + [`UdopProcessor`] offers all the functionalities you need to prepare data for the model. + + It first uses [`LayoutLMv3ImageProcessor`] to resize, rescale and normalize document images, and optionally applies OCR + to get words and normalized bounding boxes. These are then provided to [`UdopTokenizer`] or [`UdopTokenizerFast`], + which turns the words and bounding boxes into token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`. + Optionally, one can provide integer `word_labels`, which are turned into token-level `labels` for token + classification tasks (such as FUNSD, CORD). + + Additionally, it also supports passing `text_target` and `text_pair_target` to the tokenizer, which can be used to + prepare labels for language modeling tasks. + + Args: + image_processor (`LayoutLMv3ImageProcessor`): + An instance of [`LayoutLMv3ImageProcessor`]. The image processor is a required input. + tokenizer (`UdopTokenizer` or `UdopTokenizerFast`): + An instance of [`UdopTokenizer`] or [`UdopTokenizerFast`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "LayoutLMv3ImageProcessor" + tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast") + # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + optional_call_args = ["text_pair"] + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + # The following is to capture `text_pair` argument that may be passed as a positional argument. 
+        # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
+        # or this conversation for more context:
+        # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
+        # This behavior is only needed for backward compatibility and will be removed in future versions.
+        #
+        *args,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[UdopProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        This method first forwards the `images` argument to [`~LayoutLMv3ImageProcessor.__call__`]. In case
+        [`LayoutLMv3ImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
+        bounding boxes along with the additional arguments to [`~UdopTokenizer.__call__`] and returns the output,
+        together with the prepared `pixel_values`. In case [`LayoutLMv3ImageProcessor`] was initialized with
+        `apply_ocr` set to `False`, it passes the words (`text`/`text_pair`) and `boxes` specified by the user along
+        with the additional arguments to [`~UdopTokenizer.__call__`] and returns the output, together with the
+        prepared `pixel_values`.
+
+        Alternatively, one can pass `text_target` and `text_pair_target` to prepare the targets of UDOP.
+
+        Please refer to the docstring of the above two methods for more information.
+        """
+        # verify input
+        output_kwargs = self._merge_kwargs(
+            UdopProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+            **self.prepare_and_validate_optional_call_args(*args),
+        )
+
+        boxes = output_kwargs["text_kwargs"].pop("boxes", None)
+        word_labels = output_kwargs["text_kwargs"].pop("word_labels", None)
+        text_pair = output_kwargs["text_kwargs"].pop("text_pair", None)
+        return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False)
+        return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False)
+        text_target = output_kwargs["text_kwargs"].get("text_target", None)
+
+        if self.image_processor.apply_ocr and (boxes is not None):
+            raise ValueError(
+                "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True."
+            )
+
+        if self.image_processor.apply_ocr and (word_labels is not None):
+            raise ValueError(
+                "You cannot provide word labels if you initialized the image processor with apply_ocr set to True."
+ ) + + if return_overflowing_tokens and not return_offsets_mapping: + raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") + + if text_target is not None: + # use the processor to prepare the targets of UDOP + return self.tokenizer( + **output_kwargs["text_kwargs"], + ) + + else: + # use the processor to prepare the inputs of UDOP + # first, apply the image processor + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + features_words = features.pop("words", None) + features_boxes = features.pop("boxes", None) + + output_kwargs["text_kwargs"].pop("text_target", None) + output_kwargs["text_kwargs"].pop("text_pair_target", None) + output_kwargs["text_kwargs"]["text_pair"] = text_pair + output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes + output_kwargs["text_kwargs"]["word_labels"] = word_labels + + # second, apply the tokenizer + if text is not None and self.image_processor.apply_ocr and text_pair is None: + if isinstance(text, str): + text = [text] # add batch dimension (as the image processor always adds a batch dimension) + output_kwargs["text_kwargs"]["text_pair"] = features_words + + encoded_inputs = self.tokenizer( + text=text if text is not None else features_words, + **output_kwargs["text_kwargs"], + ) + + # add pixel values + if return_overflowing_tokens is True: + features["pixel_values"] = self.get_overflowing_images( + features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"] + ) + features.update(encoded_inputs) + + return features + + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images + def get_overflowing_images(self, images, overflow_to_sample_mapping): + # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image + images_with_overflow = [] + for sample_idx in overflow_to_sample_mapping: + images_with_overflow.append(images[sample_idx]) + + if len(images_with_overflow) != len(overflow_to_sample_mapping): + raise ValueError( + "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got" + f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}" + ) + + return images_with_overflow + + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.batch_decode + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.decode + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. 
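+
+        Example (a hypothetical list of token ids; any checkpoint shipping a UDOP tokenizer works):
+
+        ```python
+        >>> from transformers import UdopProcessor
+
+        >>> processor = UdopProcessor.from_pretrained("microsoft/udop-large")
+        >>> text = processor.decode([0, 1], skip_special_tokens=True)
+        ```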
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + return ["pixel_values", "input_ids", "bbox", "attention_mask"] + + +__all__ = ["UdopProcessor"] diff --git a/docs/transformers/src/transformers/models/udop/tokenization_udop.py b/docs/transformers/src/transformers/models/udop/tokenization_udop.py new file mode 100644 index 0000000000000000000000000000000000000000..08eccaec7ba8371c13a15189979cde660fb15e83 --- /dev/null +++ b/docs/transformers/src/transformers/models/udop/tokenization_udop.py @@ -0,0 +1,1490 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for UDOP model.""" + +import os +import re +import warnings +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import sentencepiece as spm + +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import ( + AddedToken, + BatchEncoding, + EncodedInput, + PreTokenizedInput, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + + +UDOP_ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. 
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. 
+
+              [What are input IDs?](../glossary#input-ids)
+
+            - **bbox** -- List of bounding boxes to be fed to a model.
+
+            - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+              if *"token_type_ids"* is in `self.model_input_names`).
+
+              [What are token type IDs?](../glossary#token-type-ids)
+
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+              [What are attention masks?](../glossary#attention-mask)
+
+            - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified).
+            - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+              regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+            - **length** -- The length of the inputs (when `return_length=True`).
+"""
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+@requires(backends=("sentencepiece",))
+class UdopTokenizer(PreTrainedTokenizer):
+    """
+    Adapted from [`LayoutXLMTokenizer`] and [`T5Tokenizer`]. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`):
+            The bounding box to use for the special [SEP] token.
+        pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+            The bounding box to use for the special [PAD] token.
+        pad_token_label (`int`, *optional*, defaults to -100):
+            The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+            CrossEntropyLoss.
+        only_label_first_subword (`bool`, *optional*, defaults to `True`):
+            Whether or not to only label the first subword, in case word labels are provided.
+        additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+            Additional special tokens used by the tokenizer.
+
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method.
+            The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        legacy (`bool`, *optional*, defaults to `True`):
+            Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622
+            which includes fixes to properly handle tokens that appear after special tokens. A simple example:
+
+            - `legacy=True`:
+            ```python
+            >>> from transformers import T5Tokenizer
+
+            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
+            >>> tokenizer.encode("Hello <extra_id_0>.")
+            [8774, 32099, 3, 5, 1]
+            ```
+            - `legacy=False`:
+            ```python
+            >>> from transformers import T5Tokenizer
+
+            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
+            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
+            [8774, 32099, 5, 1]
+            ```
+            Checkout the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for
+            more details.
+        add_prefix_space (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+            other word.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
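+
+    Example (basic usage with pre-extracted words and normalized boxes; the values are illustrative):
+
+    ```python
+    >>> from transformers import UdopTokenizer
+
+    >>> tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large")
+    >>> words = ["hello", "world"]
+    >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+    >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
+    ```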
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        eos_token="</s>",
+        unk_token="<unk>",
+        sep_token="</s>",
+        pad_token="<pad>",
+        sep_token_box=[1000, 1000, 1000, 1000],
+        pad_token_box=[0, 0, 0, 0],
+        pad_token_label=-100,
+        only_label_first_subword=True,
+        additional_special_tokens=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        legacy=True,
+        add_prefix_space=True,
+        **kwargs,
+    ) -> None:
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+
+        self.legacy = legacy
+        self.add_prefix_space = add_prefix_space
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.vocab_file = vocab_file
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        # additional properties
+        self.sep_token_box = sep_token_box
+        self.pad_token_box = pad_token_box
+        self.pad_token_label = pad_token_label
+        self.only_label_first_subword = only_label_first_subword
+
+        super().__init__(
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            sep_token_box=sep_token_box,
+            pad_token_box=pad_token_box,
+            pad_token_label=pad_token_label,
+            only_label_first_subword=only_label_first_subword,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            legacy=legacy,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return len(self.sp_model)
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
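+
+        Example (illustrative ids; with the defaults, only the final `</s>` counts as special):
+
+        ```python
+        >>> tokenizer.get_special_tokens_mask([10, 11, 12])
+        [0, 0, 0, 1]
+        ```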
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        # normal case: some special tokens
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + [1]
+        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_sentinel_tokens
+    def get_sentinel_tokens(self):
+        return list(
+            set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
+        )
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_sentinel_token_ids
+    def get_sentinel_token_ids(self):
+        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
+    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
+        """Do not add eos again if user already added it."""
+        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+            warnings.warn(
+                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+                " eos tokens being added."
+            )
+            return token_ids
+        else:
+            return token_ids + [self.eos_token_id]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
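+
+        Example (illustrative ids; assumes the usual T5-style `eos_token_id` of 1):
+
+        ```python
+        >>> tokenizer.build_inputs_with_special_tokens([10, 11])
+        [10, 11, 1]
+        ```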
+        """
+        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+        if token_ids_1 is None:
+            return token_ids_0
+        else:
+            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+            return token_ids_0 + token_ids_1
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
+    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
+        """
+        Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
+        first token is special.
+        """
+        if self.legacy or len(text) == 0:
+            return super().tokenize(text, **kwargs)
+
+        text = text.replace(SPIECE_UNDERLINE, " ")
+        if self.add_prefix_space:
+            text = SPIECE_UNDERLINE + text
+
+        tokens = super().tokenize(text, **kwargs)
+
+        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+            tokens = tokens[1:]
+        return tokens
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
+    def _tokenize(self, text, **kwargs):
+        """
+        Returns a tokenized string.
+
+        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`:
+        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
+        """
+        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+            return self.sp_model.encode(text, out_type=str)
+
+        # 1. Encode string + prefix ex: "<unk> Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        # since we manually add the prefix space, we have to remove it when decoding
+        if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
+            tokens[0] = tokens[0][1:]
+
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING)
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+        boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+        word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text_pair_target: Optional[
+            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+        ] = None,
+        **kwargs,
+    ) -> BatchEncoding:
+        if text is None and text_target is None:
+            raise ValueError("You need to specify either `text` or `text_target`.")
+        if text is not None:
+            # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
+            # input mode in this case.
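+            # (when already inside a target-tokenizer context, the tokenizer stays in target mode)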
+            if not self._in_target_context_manager:
+                self._switch_to_input_mode()
+            encodings = self.call_boxes(text=text, text_pair=text_pair, boxes=boxes, word_labels=word_labels, **kwargs)
+        if text_target is not None:
+            self._switch_to_target_mode()
+            target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **kwargs)
+        # Leave the tokenizer back in input mode
+        self._switch_to_input_mode()
+
+        if text_target is None:
+            return encodings
+        elif text is None:
+            return target_encodings
+        else:
+            encodings["labels"] = target_encodings["input_ids"]
+            return encodings
+
+    def call_boxes(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+        text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+        boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+        word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+        sequences with word-level normalized bounding boxes and optional labels.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+                (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+                words).
+            text_pair (`List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+                (pretokenized string).
+            boxes (`List[List[int]]`, `List[List[List[int]]]`):
+                Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+            word_labels (`List[int]`, `List[List[int]]`, *optional*):
+                Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+        """
+
+        # Input type checking for clearer error
+        def _is_valid_text_input(t):
+            if isinstance(t, str):
+                # Strings are fine
+                return True
+            elif isinstance(t, (list, tuple)):
+                # Lists are fine as long as they are...
+                if len(t) == 0:
+                    # ... empty
+                    return True
+                elif isinstance(t[0], str):
+                    # ... a list of strings
+                    return True
+                elif isinstance(t[0], (list, tuple)):
+                    # ... a list with an empty list or with a list of strings
+                    return len(t[0]) == 0 or isinstance(t[0][0], str)
+                else:
+                    return False
+            else:
+                return False
+
+        if text_pair is not None:
+            # in case text + text_pair are provided, text = questions, text_pair = words
+            if not _is_valid_text_input(text):
+                raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
+            if not isinstance(text_pair, (list, tuple)):
+                raise ValueError(
+                    "words must be of type `List[str]` (single pretokenized example), "
+                    "or `List[List[str]]` (batch of pretokenized examples)."
+                )
+        else:
+            # in case only text is provided => must be words
+            if not isinstance(text, (list, tuple)):
+                raise ValueError(
+                    "Words must be of type `List[str]` (single pretokenized example), "
+                    "or `List[List[str]]` (batch of pretokenized examples)."
+                )
+
+        if text_pair is not None:
+            is_batched = isinstance(text, (list, tuple))
+        else:
+            is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+        words = text if text_pair is None else text_pair
+        if boxes is None:
+            raise ValueError("You must provide corresponding bounding boxes")
+        if is_batched:
+            if len(words) != len(boxes):
+                raise ValueError("You must provide words and boxes for an equal number of examples")
+            for words_example, boxes_example in zip(words, boxes):
+                if len(words_example) != len(boxes_example):
+                    raise ValueError("You must provide as many words as there are bounding boxes")
+        else:
+            if len(words) != len(boxes):
+                raise ValueError("You must provide as many words as there are bounding boxes")
+
+        if is_batched:
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+                    f" {len(text_pair)}."
+                )
+            batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+            is_pair = bool(text_pair is not None)
+            return self.batch_encode_plus_boxes(
+                batch_text_or_text_pairs=batch_text_or_text_pairs,
+                is_pair=is_pair,
+                boxes=boxes,
+                word_labels=word_labels,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs,
+            )
+        else:
+            return self.encode_plus_boxes(
+                text=text,
+                text_pair=text_pair,
+                boxes=boxes,
+                word_labels=word_labels,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs,
+            )
+
+    def batch_encode_plus_boxes(
+        self,
+        batch_text_or_text_pairs: Union[
+            List[TextInput],
+            List[TextInputPair],
+            List[PreTokenizedInput],
+        ],
+        is_pair: Optional[bool] = None,
+        boxes: Optional[List[List[List[int]]]] = None,
+        word_labels: Optional[List[List[int]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.
+
+        Args:
+            batch_text_or_text_pairs (`List[str]`, `List[Tuple[str, str]]`, `List[List[str]]`, `List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also `List[List[int]]`, `List[Tuple[List[int], List[int]]]`):
+                Batch of sequences or pairs of sequences to be encoded. This can be a list of
+                strings/string-sequences/int-sequences or a list of pairs of strings/string-sequences/int-sequences
+                (see details in `encode_plus`).
+        """
+
+        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        return self._batch_encode_plus_boxes(
+            batch_text_or_text_pairs=batch_text_or_text_pairs,
+            is_pair=is_pair,
+            boxes=boxes,
+            word_labels=word_labels,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+            **kwargs,
+        )
+
+    def encode_boxes(
+        self,
+        text: Union[TextInput, PreTokenizedInput, EncodedInput],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+        boxes: Optional[List[List[int]]] = None,
+        word_labels: Optional[List[List[int]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> List[int]:
+        """
+        Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. Same as doing
+        `self.convert_tokens_to_ids(self.tokenize(text))`.
+
+        Args:
+            text (`str`, `List[str]` or `List[int]`):
+                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
+                `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
+            text_pair (`str`, `List[str]` or `List[int]`, *optional*):
+                Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
+                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
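+
+        Example (a minimal sketch; `tokenizer` and the box values are illustrative):
+
+        ```python
+        >>> words = ["hello", "world"]
+        >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+        >>> input_ids = tokenizer.encode_boxes(words, boxes=boxes)  # a flat list of ids ending with the eos id
+        ```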
+        """
+        encoded_inputs = self.encode_plus_boxes(
+            text,
+            text_pair=text_pair,
+            boxes=boxes,
+            word_labels=word_labels,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            return_tensors=return_tensors,
+            **kwargs,
+        )
+
+        return encoded_inputs["input_ids"]
+
+    def encode_plus_boxes(
+        self,
+        text: Union[TextInput, PreTokenizedInput],
+        text_pair: Optional[PreTokenizedInput] = None,
+        boxes: Optional[List[List[int]]] = None,
+        word_labels: Optional[List[List[int]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Tokenize and prepare for the model a sequence or a pair of sequences.
+
+        <Tip>
+
+        This method is deprecated, `__call__` should be used instead.
+
+        </Tip>
+
+        Args:
+            text (`str`, `List[str]` or (for non-fast tokenizers) `List[int]`):
+                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
+                `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
+            text_pair (`str`, `List[str]` or `List[int]`, *optional*):
+                Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
+                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
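+
+        Example (a minimal sketch; keys follow from the default `model_input_names`):
+
+        ```python
+        >>> enc = tokenizer.encode_plus_boxes(["hello", "world"], boxes=[[1, 2, 3, 4], [5, 6, 7, 8]])
+        >>> sorted(enc.keys())
+        ['attention_mask', 'bbox', 'input_ids']
+        ```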
+ """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus_boxes( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ ) + + batch_outputs = self._batch_prepare_for_model_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def _batch_prepare_for_model_boxes( + self, + batch_text_or_text_pairs, + is_pair: Optional[bool] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)): + batch_text_or_text_pair, boxes_example = example + outputs = self.prepare_for_model_boxes( + batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair, + batch_text_or_text_pair[1] if is_pair else None, + boxes_example, + word_labels=word_labels[idx] if word_labels is not None else None, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + def _encode_plus_boxes( + 
self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + return self.prepare_for_model_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def prepare_for_model_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens, + truncates sequences if overflowing while taking into account the special tokens and manages a moving window + (with user defined stride) for overflowing tokens. + + Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into + token-level `labels`. The word label is used for the first token of the word, while remaining tokens are + labeled with -100, such that they will be ignored by the loss function. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. 
This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + tokens = [] + pair_tokens = [] + token_boxes = [] + pair_token_boxes = [] + labels = [] + + if text_pair is None: + if word_labels is None: + # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference) + for word, box in zip(text, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + else: + # CASE 2: token classification (training) + for word, box, label in zip(text, boxes, word_labels): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + if self.only_label_first_subword: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1)) + else: + labels.extend([label] * len(word_tokens)) + else: + # CASE 3: document visual question answering (inference) + # text = question + # text_pair = words + tokens = self.tokenize(text) + token_boxes = [self.pad_token_box for _ in range(len(tokens))] + + for word, box in zip(text_pair, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + pair_tokens.extend(word_tokens) + pair_token_boxes.extend([box] * len(word_tokens)) + + # Create ids + pair_ids + ids = self.convert_tokens_to_ids(tokens) + pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None + + # Compute the total size of the returned encodings + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) = self.truncate_sequences( + ids, + token_boxes, + pair_ids=pair_ids, + pair_token_boxes=pair_token_boxes, + labels=labels, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." 
+ ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes + encoded_inputs["overflowing_labels"] = overflowing_labels + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + token_boxes = token_boxes + [self.sep_token_box] + if pair_token_boxes: + pair_token_boxes = pair_token_boxes + [self.sep_token_box] + if labels: + labels = labels + [self.pad_token_label] + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + encoded_inputs["bbox"] = token_boxes + pair_token_boxes + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + if labels: + encoded_inputs["labels"] = labels + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm.LayoutXLMTokenizer.truncate_sequences + def truncate_sequences( + self, + ids: List[int], + token_boxes: List[List[int]], + pair_ids: Optional[List[int]] = None, + pair_token_boxes: Optional[List[List[int]]] = None, + labels: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ + Truncates a sequence pair in-place following the strategy. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + token_boxes (`List[List[int]]`): + Bounding boxes of the first sequence. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + pair_token_boxes (`List[List[int]]`, *optional*): + Bounding boxes of the second sequence. + labels (`List[int]`, *optional*): + Labels of the first sequence (for token classification tasks). + num_tokens_to_remove (`int`, *optional*, defaults to 0): + Number of tokens to remove using the truncation strategy. 
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + The strategy to follow for truncation. Can be: + + - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a + batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater + than the model maximum admissible input size). + stride (`int`, *optional*, defaults to 0): + If set to a positive number, the overflowing tokens returned will contain some tokens from the main + sequence returned. The value of this argument defines the number of additional tokens. + + Returns: + `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + overflowing tokens. + """ + if num_tokens_to_remove <= 0: + return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy == TruncationStrategy.LONGEST_FIRST: + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + if not overflowing_tokens: + window_len = min(len(ids), stride + 1) + else: + window_len = 1 + overflowing_tokens.extend(ids[-window_len:]) + overflowing_token_boxes.extend(token_boxes[-window_len:]) + overflowing_labels.extend(labels[-window_len:]) + ids = ids[:-1] + token_boxes = token_boxes[:-1] + labels = labels[:-1] + else: + if not overflowing_tokens: + window_len = min(len(pair_ids), stride + 1) + else: + window_len = 1 + overflowing_tokens.extend(pair_ids[-window_len:]) + overflowing_token_boxes.extend(pair_token_boxes[-window_len:]) + pair_ids = pair_ids[:-1] + pair_token_boxes = pair_token_boxes[:-1] + elif truncation_strategy == TruncationStrategy.ONLY_FIRST: + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + overflowing_token_boxes = token_boxes[-window_len:] + overflowing_labels = labels[-window_len:] + ids = ids[:-num_tokens_to_remove] + token_boxes = token_boxes[:-num_tokens_to_remove] + labels = labels[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_second'." 
+ ) + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + overflowing_token_boxes = pair_token_boxes[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." + ) + + return ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm.LayoutXLMTokenizer._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. 
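+        # (a mask of all ones at this point, since nothing has been padded yet)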
+ if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(padding_side)) + + return encoded_inputs + + +__all__ = ["UdopTokenizer"] diff --git a/docs/transformers/src/transformers/models/udop/tokenization_udop_fast.py b/docs/transformers/src/transformers/models/udop/tokenization_udop_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..9992da7bddc463c91ecb1ab8e50f20722c34f921 --- /dev/null +++ b/docs/transformers/src/transformers/models/udop/tokenization_udop_fast.py @@ -0,0 +1,1031 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for UDOP model.""" + +import os +from shutil import copyfile +from typing import Dict, List, Optional, Tuple, Union + +from ...tokenization_utils_base import ( + BatchEncoding, + EncodedInput, + PreTokenizedInput, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import PaddingStrategy, TensorType, add_end_docstrings, is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_udop import UdopTokenizer +else: + UdopTokenizer = None + + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +logger = logging.get_logger(__name__) + +UDOP_ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. 
+ stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **bbox** -- List of bounding boxes to be fed to a model. + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified). + - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and + `return_overflowing_tokens=True`). 
+        - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+          `return_overflowing_tokens=True`).
+        - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+          regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+        - **length** -- The length of the inputs (when `return_length=True`).
+"""
+
+
+class UdopTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" UDOP tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from
+    [`LayoutXLMTokenizer`] and [`T5Tokenizer`]. Based on
+    [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            Path to the vocabulary file.
+        tokenizer_file (`str`, *optional*):
+            Path to the tokenizer file.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`):
+            The bounding box to use for the special [SEP] token.
+        pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+            The bounding box to use for the special [PAD] token.
+        pad_token_label (`int`, *optional*, defaults to -100):
+            The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+            CrossEntropyLoss.
+        only_label_first_subword (`bool`, *optional*, defaults to `True`):
+            Whether or not to only label the first subword, in case word labels are provided.
+        additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+            Additional special tokens used by the tokenizer.
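+
+    Example (a minimal sketch; the checkpoint name and box values are illustrative):
+
+    ```python
+    >>> from transformers import UdopTokenizerFast
+
+    >>> tokenizer = UdopTokenizerFast.from_pretrained("microsoft/udop-large")
+    >>> question = "What is the date?"
+    >>> words = ["7", "February", "2024"]
+    >>> boxes = [[50, 50, 100, 80], [110, 50, 250, 80], [260, 50, 340, 80]]
+    >>> encoding = tokenizer(question, words, boxes=boxes, return_tensors="pt")
+    ```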
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = UdopTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + eos_token="", + sep_token="", + unk_token="", + pad_token="", + sep_token_box=[1000, 1000, 1000, 1000], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + additional_special_tokens=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + eos_token=eos_token, + sep_token=sep_token, + unk_token=unk_token, + pad_token=pad_token, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + + # additional properties + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + **kwargs, + ) -> BatchEncoding: + if text is None and text_target is None: + raise ValueError("You need to specify either `text` or `text_target`.") + if text is not None: + # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the + # input mode in this case. 
+            if not self._in_target_context_manager:
+                self._switch_to_input_mode()
+            encodings = self.call_boxes(text=text, text_pair=text_pair, boxes=boxes, word_labels=word_labels, **kwargs)
+        if text_target is not None:
+            self._switch_to_target_mode()
+            target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **kwargs)
+        # Leave back tokenizer in input mode
+        self._switch_to_input_mode()
+
+        if text_target is None:
+            return encodings
+        elif text is None:
+            return target_encodings
+        else:
+            encodings["labels"] = target_encodings["input_ids"]
+            return encodings
+
+    @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING)
+    def call_boxes(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+        text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+        boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+        word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+        sequences with word-level normalized bounding boxes and optional labels.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+                (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+                words).
+            text_pair (`List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+                (pretokenized string).
+            boxes (`List[List[int]]`, `List[List[List[int]]]`):
+                Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+            word_labels (`List[int]`, `List[List[int]]`, *optional*):
+                Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+        """
+
+        # Input type checking for clearer error
+        def _is_valid_text_input(t):
+            if isinstance(t, str):
+                # Strings are fine
+                return True
+            elif isinstance(t, (list, tuple)):
+                # List are fine as long as they are...
+                if len(t) == 0:
+                    # ... empty
+                    return True
+                elif isinstance(t[0], str):
+                    # ... list of strings
+                    return True
+                elif isinstance(t[0], (list, tuple)):
+                    # ... list with an empty list or with a list of strings
+                    return len(t[0]) == 0 or isinstance(t[0][0], str)
+                else:
+                    return False
+            else:
+                return False
+
+        if text_pair is not None:
+            # in case text + text_pair are provided, text = questions, text_pair = words
+            if not _is_valid_text_input(text):
+                raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
+            if not isinstance(text_pair, (list, tuple)):
+                raise ValueError(
+                    "words must be of type `List[str]` (single pretokenized example), "
+                    "or `List[List[str]]` (batch of pretokenized examples)."
+                )
+        else:
+            # in case only text is provided => must be words
+            if not isinstance(text, (list, tuple)):
+                raise ValueError(
+                    "Words must be of type `List[str]` (single pretokenized example), "
+                    "or `List[List[str]]` (batch of pretokenized examples)."
+                )
+
+        if text_pair is not None:
+            is_batched = isinstance(text, (list, tuple))
+        else:
+            is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+        words = text if text_pair is None else text_pair
+        if boxes is None:
+            raise ValueError("You must provide corresponding bounding boxes")
+        if is_batched:
+            if len(words) != len(boxes):
+                raise ValueError("You must provide words and boxes for an equal amount of examples")
+            for words_example, boxes_example in zip(words, boxes):
+                if len(words_example) != len(boxes_example):
+                    raise ValueError("You must provide as many words as there are bounding boxes")
+        else:
+            if len(words) != len(boxes):
+                raise ValueError("You must provide as many words as there are bounding boxes")
+
+        if is_batched:
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+                    f" {len(text_pair)}."
+                )
+            batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+            is_pair = bool(text_pair is not None)
+            return self.batch_encode_plus_boxes(
+                batch_text_or_text_pairs=batch_text_or_text_pairs,
+                is_pair=is_pair,
+                boxes=boxes,
+                word_labels=word_labels,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs,
+            )
+        else:
+            return self.encode_plus_boxes(
+                text=text,
+                text_pair=text_pair,
+                boxes=boxes,
+                word_labels=word_labels,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs,
+            )
+
+    # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast.tokenize
+    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
+        batched_input = [(text, pair)] if pair else [text]
+
+        self._tokenizer.encode_special_tokens = kwargs.pop(
+            "split_special_tokens", self._tokenizer.encode_special_tokens
+        )
+
+        encodings = self._tokenizer.encode_batch(
+            batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
+        )
+
+        return encodings[0].tokens
+
+    def batch_encode_plus_boxes(
+        self,
+        batch_text_or_text_pairs: Union[
+            List[TextInput],
+            List[TextInputPair],
+            List[PreTokenizedInput],
+        ],
+        is_pair: Optional[bool] = None,
+        boxes: Optional[List[List[List[int]]]] = None,
+        word_labels:
Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + batch_text_or_text_pairs (`List[str]`, `List[Tuple[str, str]]`, `List[List[str]]`, `List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also `List[List[int]]`, `List[Tuple[List[int], List[int]]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of + string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see + details in `encode_plus`). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus_boxes( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: Optional[bool] = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if not isinstance(batch_text_or_text_pairs, list): + raise TypeError(f"batch_text_or_text_pairs has to be a list (got 
{type(batch_text_or_text_pairs)})") + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + ) + + if is_pair: + batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs] + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=True, # we set this to True as LayoutLMv2 always expects pretokenized inputs + ) + + # Convert encoding to dict + # `Tokens` has type: Tuple[ + # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]], + # List[EncodingFast] + # ] + # with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=True + if word_labels is not None + else return_offsets_mapping, # we use offsets to create the labels + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0].keys(): + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + + # create the token boxes + token_boxes = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + token_boxes_example = [] + for id, sequence_id, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_encodings[batch_index].sequence_ids, + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if is_pair and sequence_id == 0: + token_boxes_example.append(self.pad_token_box) + else: + token_boxes_example.append(boxes[original_index][word_id]) + else: + if id == self.sep_token_id: + token_boxes_example.append(self.sep_token_box) + elif id == self.pad_token_id: + token_boxes_example.append(self.pad_token_box) + else: + raise ValueError("Id not recognized") + token_boxes.append(token_boxes_example) + + sanitized_tokens["bbox"] = token_boxes 
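The loop above is the core of the box alignment: each token is mapped back to its source word through the encoding's `word_ids`, and tokens with no source word (special tokens) fall back to `sep_token_box` or `pad_token_box`. Below is a library-free sketch of the same idea, with illustrative values that are not taken from this patch:

```python
# `word_ids` mimics tokenizers' Encoding.word_ids: one entry per token,
# None for special tokens such as </s>. Values are made up for illustration.
word_boxes = [[10, 10, 50, 20], [55, 10, 90, 20]]  # one box per word
word_ids = [0, 0, 1, None]  # word 0 split into two subwords; </s> closes the sequence

sep_token_box = [1000, 1000, 1000, 1000]
token_boxes = [word_boxes[w] if w is not None else sep_token_box for w in word_ids]

assert token_boxes[0] == token_boxes[1]  # both subwords share the first word's box
assert token_boxes[-1] == sep_token_box  # </s> gets the SEP box
```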
+ + # optionally, create the labels + if word_labels is not None: + labels = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + labels_example = [] + previous_token_empty = False + for id, offset, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_tokens["offset_mapping"][batch_index], + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if self.only_label_first_subword: + if offset[0] == 0 and not previous_token_empty: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels_example.append(word_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + else: + labels_example.append(word_labels[original_index][word_id]) + if self.decode(id) == "": + previous_token_empty = True + else: + previous_token_empty = False + else: + labels_example.append(self.pad_token_label) + labels.append(labels_example) + + sanitized_tokens["labels"] = labels + # finally, remove offsets if the user didn't want them + if not return_offsets_mapping: + del sanitized_tokens["offset_mapping"] + + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) + + def _encode_plus_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # make it a batched input + # 2 options: + # 1) only text, in case text must be a list of str + # 2) text + text_pair, in which case text = str and text_pair a list of str + batched_input = [(text, text_pair)] if text_pair else [text] + batched_boxes = [boxes] + batched_word_labels = [word_labels] if word_labels is not None else None + batched_output = self._batch_encode_plus_boxes( + batched_input, + is_pair=bool(text_pair is not None), + boxes=batched_boxes, + word_labels=batched_word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + # Overflowing tokens are returned as a batch of output so we keep them in this case + if return_tensors is None and not 
return_overflowing_tokens: + batched_output = BatchEncoding( + { + key: value[0] if len(value) > 0 and isinstance(value[0], list) else value + for key, value in batched_output.items() + }, + batched_output.encodings, + ) + + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + + return batched_output + + def encode_boxes( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> List[int]: + """ + Args: + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. Same as doing + `self.convert_tokens_to_ids(self.tokenize(text))`. + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + encoded_inputs = self.encode_plus_boxes( + text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def encode_plus_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + text (`str`, `List[str]` or (for non-fast tokenizers) `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. 
This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. 
+            return_attention_mask:
+                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+        """
+        # Load from model defaults
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = len(required_input)
+
+        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+        # Initialize attention mask if not present.
+        if return_attention_mask and "attention_mask" not in encoded_inputs:
+            encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+        if needs_to_be_padded:
+            difference = max_length - len(required_input)
+            padding_side = padding_side if padding_side is not None else self.padding_side
+            if padding_side == "right":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = (
+                        encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+                    )
+                if "bbox" in encoded_inputs:
+                    encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+                if "labels" in encoded_inputs:
+                    encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+            elif padding_side == "left":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+                        "token_type_ids"
+                    ]
+                if "bbox" in encoded_inputs:
+                    encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+                if "labels" in encoded_inputs:
+                    encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+            else:
+                raise ValueError("Invalid padding strategy: " + str(padding_side))
+
+        return encoded_inputs
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A UDOP sequence has the following format:
+
+        - single sequence: `X </s>`
+        - pair of sequences: `A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """ + + if token_ids_1 is None: + return token_ids_0 + [self.sep_token_id] + sep = [self.sep_token_id] + return token_ids_0 + sep + token_ids_1 + sep + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + sep) * [0] + return len(token_ids_0 + sep + token_ids_1 + sep) * [0] + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory.") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["UdopTokenizerFast"] diff --git a/docs/transformers/src/transformers/models/umt5/__init__.py b/docs/transformers/src/transformers/models/umt5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc2e041359770eca293bd2cce968e1b19d22bc2 --- /dev/null +++ b/docs/transformers/src/transformers/models/umt5/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_umt5 import * + from .modeling_umt5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/umt5/configuration_umt5.py b/docs/transformers/src/transformers/models/umt5/configuration_umt5.py new file mode 100644 index 0000000000000000000000000000000000000000..82e7dfeba5bb5527455775a089a0c498dd84adc3 --- /dev/null +++ b/docs/transformers/src/transformers/models/umt5/configuration_umt5.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023, The T5 Authors and HuggingFace Inc. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""UMT5 model configuration"""
+
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UMT5Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`UMT5Model`]. It is used to instantiate a UMT5
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the UMT5
+    [google/umt5-small](https://huggingface.co/google/umt5-small) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Arguments:
+        vocab_size (`int`, *optional*, defaults to 250112):
+            Vocabulary size of the UMT5 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`UMT5Model`] or [`TFUMT5Model`].
+        d_model (`int`, *optional*, defaults to 512):
+            Size of the encoder layers and the pooler layer.
+        d_kv (`int`, *optional*, defaults to 64):
+            Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
+            num_heads`.
+        d_ff (`int`, *optional*, defaults to 1024):
+            Size of the intermediate feed forward layer in each `UMT5Block`.
+        num_layers (`int`, *optional*, defaults to 8):
+            Number of hidden layers in the Transformer encoder.
+        num_decoder_layers (`int`, *optional*):
+            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+        num_heads (`int`, *optional*, defaults to 6):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+            The number of buckets to use for each attention layer.
+        relative_attention_max_distance (`int`, *optional*, defaults to 128):
+            The maximum distance of the longer sequences for the bucket separation.
+        dropout_rate (`float`, *optional*, defaults to 0.1):
+            The ratio for all dropout layers.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the classifier.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the layer normalization layers.
+        initializer_factor (`float`, *optional*, defaults to 1):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+        feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`):
+            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
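As a quick aid for reading the constructor below, this is how `feed_forward_proj` is decomposed into the two attributes the model blocks consume; a sketch mirroring the logic that follows, not an addition to the library:

```python
# Trace of the parsing performed in UMT5Config.__init__ (value illustrative).
feed_forward_proj = "gated-gelu"

act_info = feed_forward_proj.split("-")
dense_act_fn = act_info[-1]            # "gelu"
is_gated_act = act_info[0] == "gated"  # True -> gated (wi_0/wi_1) MLP variant

# The constructor additionally special-cases "gated-gelu" to the "gelu_new" activation.
if feed_forward_proj == "gated-gelu":
    dense_act_fn = "gelu_new"

print(dense_act_fn, is_gated_act)  # gelu_new True
```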
+ """ + + model_type = "umt5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } + + def __init__( + self, + vocab_size=250112, + d_model=512, + d_kv=64, + d_ff=1024, + num_layers=8, + num_decoder_layers=None, + num_heads=6, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="gated-gelu", + is_encoder_decoder=True, + use_cache=True, + tokenizer_class="T5Tokenizer", + tie_word_embeddings=True, + pad_token_id=0, + eos_token_id=1, + decoder_start_token_id=0, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + is_encoder_decoder=is_encoder_decoder, + tokenizer_class=tokenizer_class, + tie_word_embeddings=tie_word_embeddings, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +class UMT5OnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + # Copied from transformers.models.t5.configuration_t5.T5OnnxConfig.inputs + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + # Copied from transformers.models.t5.configuration_t5.T5OnnxConfig.default_onnx_opset + def default_onnx_opset(self) -> int: + return 13 + + @property + def atol_for_validation(self) -> float: + return 5e-4 + + +__all__ = ["UMT5Config", "UMT5OnnxConfig"] diff --git a/docs/transformers/src/transformers/models/umt5/convert_umt5_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/umt5/convert_umt5_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..848ca3c5660c5feeee02fa69c30f76aba8271497 --- /dev/null +++ b/docs/transformers/src/transformers/models/umt5/convert_umt5_checkpoint_to_pytorch.py @@ -0,0 +1,274 @@ +# coding=utf-8 +# Copyright 2023 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert T5X checkpoint to PyTorch + +Steps: +- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install +- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example: + `gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/` +- Create or download a corresponding config for the downloaded model. E.g. 
for T5 v1.1 small, you can use
+  https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
+- Convert:
+  ```
+  python3 convert_umt5_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
+      --pytorch_dump_path=$HOME/t5_1_1_small_pt
+  ```
+"""
+
+import argparse
+import collections
+
+import numpy as np
+import torch
+from flax import traverse_util
+from t5x import checkpoints
+
+from transformers import MT5Config, UMT5EncoderModel, UMT5ForConditionalGeneration
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+
+
+def t5x_relpos_bias_lookup(params, i, prefix):
+    """Returns the Relative Position Bias parameters of a layer. Does not transpose."""
+    return params[f"{prefix}/{prefix}/relpos_bias/rel_embedding"][:, i, :]
+
+
+def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
+    """Returns the KOQV parameters of (self-)attention. Does not transpose."""
+    k_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/key/kernel"][:, i, :, :])
+    k = k_tmp.reshape(k_tmp.shape[0], k_tmp.shape[1] * k_tmp.shape[2])
+    o_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/out/kernel"][:, i, :, :])
+    o = o_tmp.reshape(o_tmp.shape[0] * o_tmp.shape[1], o_tmp.shape[2])
+    q_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/query/kernel"][:, i, :, :])
+    q = q_tmp.reshape(q_tmp.shape[0], q_tmp.shape[1] * q_tmp.shape[2])
+    v_tmp = np.ascontiguousarray(params[f"{prefix}/{prefix}/{layer_name}/value/kernel"][:, i, :, :])
+    v = v_tmp.reshape(v_tmp.shape[0], v_tmp.shape[1] * v_tmp.shape[2])
+    return k, o, q, v
+
+
+def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the MLP parameters of a layer. Does not transpose."""
+    if split_mlp_wi:
+        wi_0 = params[f"{prefix}/{prefix}/mlp/wi_0/kernel"][:, i, :]
+        wi_1 = params[f"{prefix}/{prefix}/mlp/wi_1/kernel"][:, i, :]
+        wi = (wi_0, wi_1)
+    else:
+        wi = params[f"{prefix}/{prefix}/mlp/wi/kernel"][:, i, :]
+
+    wo = params[f"{prefix}/{prefix}/mlp/wo/kernel"][:, i, :]
+    return wi, wo
+
+
+def t5x_layer_norm_lookup(params, i, prefix, layer_name):
+    """Returns the layer norm param of a layer."""
+    return params[f"{prefix}/{prefix}/{layer_name}/scale"][:, i]
+
+
+def convert_t5x_to_pytorch(
+    variables: dict, *, num_layers: int, is_encoder_only: bool, scalable_attention: bool = False
+):
+    """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
+    old = traverse_util.flatten_dict(variables["target"])
+    old = {"/".join(k): v for k, v in old.items()}
+
+    # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
+    split_mlp_wi = "encoder/encoder/mlp/wi_0/kernel" in old
+    print("Split MLP:", split_mlp_wi)
+
+    new = collections.OrderedDict()
+
+    # Shared embeddings.
+    new["shared.weight"] = old["token_embedder/embedding"]
+
+    # Encoder.
+    for i in range(num_layers):
+        # Block i, layer 0 (Self Attention).
+        layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm")
+        k, o, q, v = t5x_attention_lookup(old, i, "encoder", "attention")
+        new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
+        new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
+        new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
+        new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
+        new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
+
+        # Block i, layer 1 (MLP).
+        layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm")
+        wi, wo = t5x_mlp_lookup(old, i, "encoder", split_mlp_wi)
+        new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
+        if split_mlp_wi:
+            new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T
+            new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T
+        else:
+            new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T
+        new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T
+        if scalable_attention:
+            # convert the rel_embedding of each layer
+            new[f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight"] = t5x_relpos_bias_lookup(
+                old, i, "encoder"
+            ).T
+
+    new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
+
+    if not scalable_attention:
+        new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = t5x_relpos_bias_lookup(
+            old, 0, "encoder"
+        ).T
+        new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = t5x_relpos_bias_lookup(
+            old, 0, "decoder"
+        ).T
+
+    if not is_encoder_only:
+        # Decoder.
+        for i in range(num_layers):
+            # Block i, layer 0 (Self Attention).
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
+            k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
+            new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
+            new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
+            new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
+            new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
+            new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
+
+            # Block i, layer 1 (Cross Attention).
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
+            k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
+            new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
+
+            # Block i, layer 2 (MLP).
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
+            wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
+            new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
+            if split_mlp_wi:
+                new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
+                new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
+            else:
+                new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
+            new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
+
+            if scalable_attention:
+                # convert the rel_embedding of each layer
+                new[f"decoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight"] = (
+                    t5x_relpos_bias_lookup(old, i, "decoder").T
+                )
+
+        new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
+
+        # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
+        if "decoder/logits_dense/kernel" in old:
+            new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
+
+    return new
+
+
+def make_state_dict(converted_params, is_encoder_only: bool):
+    """Prepares a state dict for the PyTorch model."""
+    # Make a state dict with torch tensors.
+    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
+
+    # Add what is missing.
+    if "encoder.embed_tokens.weight" not in state_dict:
+        state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
+
+    if not is_encoder_only:
+        if "decoder.embed_tokens.weight" not in state_dict:
+            state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
+
+        if "lm_head.weight" not in state_dict:  # For old 1.0 models.
+            print("Using shared word embeddings as lm_head.")
+            state_dict["lm_head.weight"] = state_dict["shared.weight"]
+
+    return state_dict
+
+
+def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only, scalable_attention):
+    """Replaces the params in model with the T5X converted params."""
+    variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
+    converted = convert_t5x_to_pytorch(
+        variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only, scalable_attention=scalable_attention
+    )
+    state_dict = make_state_dict(converted, is_encoder_only)
+    model.load_state_dict(state_dict, strict=True)
+
+
+def convert_t5x_checkpoint_to_pytorch(
+    t5x_checkpoint_path,
+    config_file,
+    pytorch_dump_path,
+    is_encoder_only: bool = False,
+    scalable_attention: bool = False,
+):
+    """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
+    # Initialise PyTorch model
+    config = MT5Config.from_json_file(config_file)
+    print(f"Building PyTorch model from configuration: {config}")
+    # Non-v1.1 checkpoints could also use T5Model, but this works for all.
+    # The v1.0 checkpoints will simply have an LM head that is the word embeddings.
+    if is_encoder_only:
+        model = UMT5EncoderModel(config)
+    else:
+        model = UMT5ForConditionalGeneration(config)
+
+    # Load weights from the T5X checkpoint
+    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only, scalable_attention)
+
+    # Save pytorch-model
+    print(f"Save PyTorch model to {pytorch_dump_path}")
+    model.save_pretrained(pytorch_dump_path)
+
+    # Verify that we can load the checkpoint.
+    model.from_pretrained(pytorch_dump_path)
+    print("Done")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
+    # Required parameters
+    parser.add_argument(
+        "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
+    )
+    parser.add_argument(
+        "--config_file",
+        default=None,
+        type=str,
+        required=True,
+        help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+ ) + parser.add_argument( + "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False + ) + parser.add_argument( + "--scalable_attention", + action="store_true", + help="Whether the model uses scaled attention (umt5 model)", + default=False, + ) + args = parser.parse_args() + convert_t5x_checkpoint_to_pytorch( + args.t5x_checkpoint_path, + args.config_file, + args.pytorch_dump_path, + args.is_encoder_only, + args.scalable_attention, + ) diff --git a/docs/transformers/src/transformers/models/umt5/modeling_umt5.py b/docs/transformers/src/transformers/models/umt5/modeling_umt5.py new file mode 100644 index 0000000000000000000000000000000000000000..44586243c3fa01588f9f2205c3eaaa5470deb11b --- /dev/null +++ b/docs/transformers/src/transformers/models/umt5/modeling_umt5.py @@ -0,0 +1,2039 @@ +# coding=utf-8 +# Copyright 2023 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UMT5 model.""" + +import copy +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + is_torch_fx_proxy, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_umt5 import UMT5Config + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "UMT5Config" +_CHECKPOINT_FOR_DOC = "google/umt5-small" + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->UMT5 +class UMT5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the UMT5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # UMT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->UMT5 +class UMT5DenseActDense(nn.Module): + def __init__(self, config: UMT5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->UMT5 +class UMT5DenseGatedActDense(nn.Module): + def __init__(self, config: UMT5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->UMT5 +class UMT5LayerFF(nn.Module): + def __init__(self, config: UMT5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = UMT5DenseGatedActDense(config) + else: + self.DenseReluDense = UMT5DenseActDense(config) + + self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class UMT5Attention(nn.Module): + """ + T5's attention using relative_attention_bias. 
+ """ + + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + + def _shape(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.n_heads, self.key_value_proj_dim) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def _relative_position_bucket(self, relative_position): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on.
+
+ Args:
+ relative_position: an int32 Tensor
+
+ Unlike the Mesh TensorFlow original, `bidirectional`, `num_buckets` and `max_distance` are not passed as
+ arguments here: they are read from the module as `not self.is_decoder`,
+ `self.relative_attention_num_buckets` and `self.relative_attention_max_distance`.
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ num_buckets = self.relative_attention_num_buckets
+ max_distance = self.relative_attention_max_distance
+ if not self.is_decoder:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ log_ratio = torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact)
+ log_ratio = log_ratio * (num_buckets - max_exact)
+ relative_position_if_large = max_exact + log_ratio.to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length, device=None, cache_position=None):
+ """Compute binned relative position bias"""
+ if device is None:
+ device = self.relative_attention_bias.weight.device
+ if cache_position is None:
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ else:
+ context_position = cache_position[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+ relative_position = memory_position - context_position # shape (query_length, key_length)
+ relative_position_bucket = self._relative_position_bucket(relative_position)
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
+ return values
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = encoder_hidden_states is not None
+
+ query_states = self.q(hidden_states)
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
+ if
is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length + key_length = key_states.shape[-2] + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, -1) + + attn_output = self.o(attn_output) + return attn_output, attn_weights, past_key_value + + +class UMT5LayerSelfAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx) + self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + layer_head_mask=None, + past_key_value=None, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + 
cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class UMT5LayerCrossAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + layer_head_mask=None, + past_key_value=None, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class UMT5Block(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx)) + if self.is_decoder: + self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx)) + + self.layer.append(UMT5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + hidden_states, self_attn_weights, past_key_value = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + cache_position=cache_position, + ) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + max_dtype = torch.finfo(hidden_states.dtype).max + clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Cross-Attention Block + cross_attn_weights = None + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + hidden_states, cross_attn_weights, past_key_value = self.layer[1]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + cache_position=cache_position, + ) + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + max_dtype = torch.finfo(hidden_states.dtype).max + clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + max_dtype = torch.finfo(hidden_states.dtype).max + clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) + hidden_states = torch.clamp(hidden_states, 
min=-clamp_value, max=clamp_value) + + outputs = ( + hidden_states, + past_key_value, + ) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->UMT5 +class UMT5ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: UMT5Config): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class UMT5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = UMT5Config + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True + _no_split_modules = ["UMT5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, UMT5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + ( + UMT5Model, + UMT5ForConditionalGeneration, + UMT5EncoderModel, + UMT5ForQuestionAnswering, + ), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, UMT5ForTokenClassification): + if hasattr(module, "classifier"): + module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.data.zero_() + elif isinstance(module, UMT5ClassificationHead): + module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, UMT5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * 
((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, UMT5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, UMT5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In UMT5 it is usually set to the pad_token_id. " + "See UMT5 docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
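+ # Instead of assigning in place, build the start-token column with torch.full and concatenate
+ # it in front of input_ids[..., :-1]; this is equivalent to the in-place shift in the else branch.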
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class UMT5Stack(UMT5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)]) + self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + if self.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.is_decoder else None + + 
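+ # head_mask and cross_attn_head_mask were expanded above to one entry per layer; each block below
+ # consumes its own slice while all blocks share the same (optional) cache object.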
hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + causal_mask, + encoder_hidden_states, + encoder_extended_attention_mask, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + cache_position, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[1] + + if output_attentions: + all_attentions += (layer_outputs[2],) + if self.is_decoder: + all_cross_attentions += (layer_outputs[3],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
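+ # (here "inverted" means masked positions already hold a large negative value such as
+ # torch.finfo(dtype).min, rather than the 0/1 convention of 2D padding masks)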
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+UMT5_START_DOCSTRING = r"""
+
+ The UMT5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
+ text-to-text denoising generative setting.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`UMT5Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+UMT5_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ To know more on how to prepare `input_ids` for pretraining take a look at [UMT5 Training](./umt5#training).
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation.
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input
+ (see `past_key_values`).
+
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
+ Training](./umt5#training).
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+ `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+"""
+
+UMT5_ENCODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ To know more on how to prepare `input_ids` for pretraining take a look at [UMT5 Training](./umt5#training).
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare UMT5 Model transformer outputting raw hidden-states without any specific head on top.",
+ UMT5_START_DOCSTRING,
+)
+class UMT5Model(UMT5PreTrainedModel):
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import UMT5Model, AutoTokenizer
+
+ >>> model = UMT5Model.from_pretrained("google/umt5-small")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> noisy_text = "UN Offizier sagt, dass weiter <extra_id_0> werden muss in Syrien."
+ >>> label = " verhandelt" + >>> inputs = tokenizer(inputs, return_tensors="pt") + >>> labels = tokenizer(label=label, return_tensors="pt") + + >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"]) + >>> hidden_states = outputs.last_hidden_state + ```""" + + model_type = "umt5" + config_class = UMT5Config + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UMT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = UMT5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings + def get_input_embeddings(self): + return self.shared + + # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder + def get_encoder(self): + return self.encoder + + # Copied from transformers.models.t5.modeling_t5.T5Model.get_decoder + def get_decoder(self): + return self.decoder + + # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, UMT5Model
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> model = UMT5Model.from_pretrained("google/umt5-small")
+
+ >>> input_ids = tokenizer(
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
+
+ >>> # preprocess: Prepend decoder_input_ids with start token which is the pad token for UMT5Model.
+ >>> # This is not needed for torch's UMT5ForConditionalGeneration as it does this internally using labels arg.
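+ >>> # _shift_right prepends config.decoder_start_token_id (the pad token) and drops the last token.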
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""UMT5 Model with a `language modeling` head on top.""", UMT5_START_DOCSTRING) +class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin): + r""" + Examples: + + ```python + >>> from transformers import UMT5ForConditionalGeneration, AutoTokenizer + + >>> model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." 
+ >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> loss = outputs.loss + ```""" + + model_type = "umt5" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UMT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = UMT5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_input_embeddings + def get_input_embeddings(self): + return self.shared + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder + def get_encoder(self): + return self.encoder + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_decoder + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape 
+ `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
+ labels in `[0, ..., config.vocab_size]`
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, UMT5ForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")
+
+ >>> # training
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
+ >>> outputs = model(input_ids=input_ids, labels=labels)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+
+ >>> # inference
+ >>> input_ids = tokenizer("Studies have shown that <extra_id_0> good for you", return_tensors="pt").input_ids
+ >>> outputs = model.generate(input_ids)
+ >>> tokenizer.decode(outputs[0], skip_special_tokens=True)
+ ```"""
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ # Convert encoder inputs in embeddings if needed
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+ # get decoder inputs from shifting lm labels to the right
+ decoder_input_ids = self._shift_right(labels)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim**-0.5)
+
+ lm_logits = self.lm_head(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ # move labels to correct device to enable PP
+ labels = labels.to(lm_logits.device)
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=lm_logits,
past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + "The bare UMT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + UMT5_START_DOCSTRING, +) +class UMT5EncoderModel(UMT5PreTrainedModel): + r""" + Examples: + + ```python + >>> from transformers import UMT5EncoderModel, AutoTokenizer + + >>> model = UMT5EncoderModel.from_pretrained("google/umt5-small") + >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> input_ids = tokenizer(article, return_tensors="pt").input_ids + >>> outputs = model(input_ids) + >>> hidden_state = outputs.last_hidden_state + ```""" + + model_type = "umt5" + # config_class = UMT5Config + _tied_weights_keys = ["encoder.embed_tokens.weight"] + + def __init__(self, config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UMT5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_input_embeddings + def get_input_embeddings(self): + return self.shared + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder + def get_encoder(self): + return self.encoder + + # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base
+ class [`PreTrainedModel`].
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(UMT5_ENCODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with T5->UMT5, google-t5/t5-small->google/umt5-small
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, UMT5EncoderModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> model = UMT5EncoderModel.from_pretrained("google/umt5-small")
+ >>> input_ids = tokenizer(
+ ... "Studies have shown that owning a dog is good for you", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+ >>> outputs = model(input_ids=input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return encoder_outputs
+
+
+@add_start_docstrings(
+ """
+ UMT5 model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
+ tasks.
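+
+ Example (a minimal sketch; "google/umt5-small" ships no fine-tuned classification head, so the head below
+ is randomly initialized and its logits are only meaningful after fine-tuning):
+
+ ```python
+ >>> from transformers import AutoTokenizer, UMT5ForSequenceClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> model = UMT5ForSequenceClassification.from_pretrained("google/umt5-small", num_labels=2)
+ >>> inputs = tokenizer("A walk in the park", return_tensors="pt")
+ >>> logits = model(**inputs).logits # shape: (batch_size, num_labels)
+ ```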
+ """, + UMT5_START_DOCSTRING, +) +class UMT5ForSequenceClassification(UMT5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->UMT5 + def __init__(self, config: UMT5Config): + super().__init__(config) + self.transformer = UMT5Model(config) + self.classification_head = UMT5ClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + self.model_parallel = False + + @add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates + # decoder_input_ids from input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ )
+ decoder_input_ids = self._shift_right(input_ids)
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
+
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+ raise ValueError("All examples must have the same number of <eos> tokens.")
+ batch_size, _, hidden_size = sequence_output.shape
+ sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
+ logits = self.classification_head(sentence_representation)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.config.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.config.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ UMT5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
+ e.g. for Named-Entity-Recognition (NER) tasks.
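+
+ Example (a minimal sketch; "google/umt5-small" ships no token-classification head, so the head below is
+ randomly initialized):
+
+ ```python
+ >>> from transformers import AutoTokenizer, UMT5ForTokenClassification
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> model = UMT5ForTokenClassification.from_pretrained("google/umt5-small", num_labels=5)
+ >>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
+ >>> logits = model(**inputs).logits # shape: (batch_size, sequence_length, num_labels)
+ ```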
+
+ """,
+ UMT5_START_DOCSTRING,
+)
+class UMT5ForTokenClassification(UMT5PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
+ _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->UMT5
+ def __init__(self, config: UMT5Config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = UMT5EncoderModel(config)
+ self.dropout = nn.Dropout(config.classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->UMT5
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ Returns:
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ UMT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
+ on top of the hidden-states output to compute `span start logits` and `span end logits`).
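+
+ Example (a minimal sketch; the span head of "google/umt5-small" is randomly initialized, so the predicted
+ span is arbitrary until the model is fine-tuned on a QA dataset):
+
+ ```python
+ >>> from transformers import AutoTokenizer, UMT5ForQuestionAnswering
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
+ >>> model = UMT5ForQuestionAnswering.from_pretrained("google/umt5-small")
+ >>> inputs = tokenizer("Who wrote the book?", "The book was written by Jane.", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> start_logits, end_logits = outputs.start_logits, outputs.end_logits
+ ```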
+
+ """,
+ UMT5_START_DOCSTRING,
+)
+class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model_dim = config.d_model
+
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = UMT5Stack(encoder_config, self.shared)
+
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = UMT5Stack(decoder_config, self.shared)
+
+ self.num_labels = config.num_labels
+ self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.shared
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
+ def _tie_weights(self):
+ if self.config.tie_word_embeddings:
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
+ def get_encoder(self):
+ return self.encoder
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_decoder
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(UMT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ Returns:
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ if start_positions is not None and end_positions is not None:
+ use_cache = False
+
+ # Copied from models.bart.modeling_bart.BartModel.forward
+ # different from other models, T5 automatically creates decoder_input_ids from
+ # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError(
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+ "passed, `input_ids` cannot be `None`. Please pass either "
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+ )
+ decoder_input_ids = self._shift_right(input_ids)
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=None,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+
output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +__all__ = [ + "UMT5EncoderModel", + "UMT5ForConditionalGeneration", + "UMT5ForQuestionAnswering", + "UMT5ForSequenceClassification", + "UMT5ForTokenClassification", + "UMT5Model", + "UMT5PreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/unispeech/__init__.py b/docs/transformers/src/transformers/models/unispeech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a03daae25f4535c96c32e20dc4b13a817518bc06 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_unispeech import * + from .modeling_unispeech import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/unispeech/configuration_unispeech.py b/docs/transformers/src/transformers/models/unispeech/configuration_unispeech.py new file mode 100644 index 0000000000000000000000000000000000000000..ccfe52f79c1c5f68fc041fb966ea3c0ed6c5927f --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech/configuration_unispeech.py @@ -0,0 +1,309 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
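+# Note: with the default `conv_stride` of (5, 2, 2, 2, 2, 2, 2), the feature encoder downsamples raw
+# 16 kHz audio by a factor of 5 * 2**6 = 320, i.e. one output frame every 20 ms; the
+# `inputs_to_logits_ratio` property at the end of this configuration exposes this stride product.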
+"""UniSpeech model configuration""" + +import functools +import operator + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class UniSpeechConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`UniSpeechModel`]. It is used to instantiate an + UniSpeech model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the UniSpeech + [microsoft/unispeech-large-1500h-cv](https://huggingface.co/microsoft/unispeech-large-1500h-cv) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32): + Vocabulary size of the UniSpeech model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`UniSpeechModel`]. Vocabulary size of the model. Defines the + different tokens that can be represented by the *inputs_ids* passed to the forward method of + [`UniSpeechModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for activations inside the fully connected layer. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + feat_proj_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for output of the feature encoder. + feat_quantizer_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the output of the feature encoder that's used by the quantizer. + final_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the final projection layer of [`UniSpeechForCTC`]. + layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more + details. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + feat_extract_norm (`str`, *optional*, defaults to `"group"`): + The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group + normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D + convolutional layers. 
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+ *conv_dim*.
+ conv_bias (`bool`, *optional*, defaults to `False`):
+ Whether the 1D convolutional layers have a bias.
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
+ embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of the 1D convolutional positional embeddings layer.
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
+ Whether to apply the *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
+ False` corresponds to applying layer norm after the attention layer.
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+ Recognition](https://arxiv.org/abs/1904.08779).
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+ procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+ irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks`.
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over
+ the axis.
If reasoning from the probability of each feature vector to be chosen as the start of the vector
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+ True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespective of `mask_feature_prob`. Only relevant if
+ `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
+ Number of entries in each quantization codebook (group).
+ num_codevector_groups (`int`, *optional*, defaults to 2):
+ Number of codevector groups for product codevector quantization.
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
+ The temperature *kappa* in the contrastive loss.
+ num_negatives (`int`, *optional*, defaults to 100):
+ Number of negative samples for the contrastive loss.
+ codevector_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of the quantized feature vectors.
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of the final projection of both the quantized and the transformer features.
+ diversity_loss_weight (`float`, *optional*, defaults to 0.1):
+ The weight of the codebook diversity loss component.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`UniSpeechForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`UniSpeechForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`UniSpeechForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the projection before token mean-pooling for classification.
+ num_ctc_classes (`int`, *optional*, defaults to 80):
+ Specifies the number of classes (phoneme tokens and blank token) for phoneme-level CTC loss. Only relevant
+ when using an instance of [`UniSpeechForPreTraining`].
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ replace_prob (`float`, *optional*, defaults to 0.5):
+ Probability that a transformer feature is replaced by its quantized counterpart during pretraining.
+
+ Example:
+
+ ```python
+ >>> from transformers import UniSpeechConfig, UniSpeechModel
+
+ >>> # Initializing a UniSpeech microsoft/unispeech-large-1500h-cv style configuration
+ >>> configuration = UniSpeechConfig()
+
+ >>> # Initializing a model (with random weights) from the microsoft/unispeech-large-1500h-cv style configuration
+ >>> model = UniSpeechModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "unispeech"
+
+ def __init__(
+ self,
+ vocab_size=32,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ feat_proj_dropout=0.0,
+ feat_quantizer_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ feat_extract_norm="group",
+ feat_extract_activation="gelu",
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+ conv_bias=False,
+ num_conv_pos_embeddings=128,
+ num_conv_pos_embedding_groups=16,
+ do_stable_layer_norm=False,
+ apply_spec_augment=True,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ num_codevectors_per_group=320,
+ num_codevector_groups=2,
+ contrastive_logits_temperature=0.1,
+ num_negatives=100,
+ codevector_dim=256,
+ proj_codevector_dim=256,
+ diversity_loss_weight=0.1,
+ ctc_loss_reduction="mean",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=256,
+ num_ctc_classes=80,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ replace_prob=0.5,
+ **kwargs,
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = list(conv_dim)
+ self.conv_stride = list(conv_stride)
+ self.conv_kernel = list(conv_kernel)
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_feat_extract_layers = len(self.conv_dim)
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.feat_proj_dropout = feat_proj_dropout
+ self.final_dropout = final_dropout
+ self.layerdrop = layerdrop
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.num_ctc_classes = num_ctc_classes
+ self.vocab_size = vocab_size
+ self.do_stable_layer_norm = do_stable_layer_norm
+ self.use_weighted_layer_sum = use_weighted_layer_sum
+ self.classifier_proj_size = classifier_proj_size
+
+ if (
+ (len(self.conv_stride) != self.num_feat_extract_layers)
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
+ ):
+ raise ValueError(
+ "Configuration for convolutional layers is incorrect. 
It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # parameters for pretraining with codevector quantized representations + self.num_codevectors_per_group = num_codevectors_per_group + self.num_codevector_groups = num_codevector_groups + self.contrastive_logits_temperature = contrastive_logits_temperature + self.feat_quantizer_dropout = feat_quantizer_dropout + self.num_negatives = num_negatives + self.codevector_dim = codevector_dim + self.proj_codevector_dim = proj_codevector_dim + self.diversity_loss_weight = diversity_loss_weight + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # pretraining loss + self.replace_prob = replace_prob + + @property + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +__all__ = ["UniSpeechConfig"] diff --git a/docs/transformers/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb8dfa7bbd293f41c2302cadbcbefe379da32df --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
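+# The conversion strategy below: walk the fairseq state dict and copy each tensor into the matching
+# HF submodule. `MAPPING` translates fairseq parameter prefixes into HF attribute paths (with "*"
+# standing in for the encoder layer index), while the convolutional feature-extractor weights are
+# routed separately through `load_conv_layer`.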
+"""Convert UniSpeech checkpoint.""" + +import argparse +import json +import os + +import fairseq +import torch +from fairseq.data import Dictionary + +from transformers import ( + UniSpeechConfig, + UniSpeechForCTC, + UniSpeechForPreTraining, + Wav2Vec2FeatureExtractor, + Wav2Vec2PhonemeCTCTokenizer, + Wav2Vec2Processor, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "ctc_proj", + "mask_emb": "masked_spec_embed", +} +TOP_LEVEL_KEYS = [ + "ctc_proj", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type, is_finetuned): + for attribute in key.split("."): + if is_finetuned: + if attribute in ["quantizer", "project_q", "project_hid"]: + # those layers are only relevant for pretraining and should be dropped + return + + if attribute == "ctc_proj": + # we should rename `ctc_proj` to `lm_head` for fine-tuned phoneme models + attribute = "lm_head" + + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert hf_shape == value.shape, ( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights(fairseq_model, hf_model, is_finetuned): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.unispeech.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + mapped_key = "unispeech." 
+ mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + # TODO: don't match quantizer.weight_proj + weight_type = "weight" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type, is_finetuned) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, ( + f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was" + " found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, ( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +@torch.no_grad() +def convert_unispeech_checkpoint( + checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True +): + """ + Copy/paste/tweak model's weights to transformers design. 
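+
+ Example (a sketch; all paths are placeholders, and the dump folder must already exist):
+
+ ```python
+ convert_unispeech_checkpoint(
+ checkpoint_path="/path/to/unispeech_checkpoint.pt",
+ pytorch_dump_folder_path="/path/to/output_dir",
+ dict_path="/path/to/dict.ltr.txt",
+ is_finetuned=True,
+ )
+ ```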
+ """
+ if config_path is not None:
+ config = UniSpeechConfig.from_pretrained(config_path)
+ else:
+ config = UniSpeechConfig()
+
+ if is_finetuned:
+ if dict_path:
+ target_dict = Dictionary.load_from_json(dict_path)
+
+ # important change: adjust bos & pad token ids, since the CTC symbol is <pad> and
+ # not <s> as in fairseq
+ config.bos_token_id = target_dict.pad_index
+ config.pad_token_id = target_dict.bos_index
+ config.eos_token_id = target_dict.eos_index
+ config.vocab_size = len(target_dict.symbols)
+ vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
+ if not os.path.isdir(pytorch_dump_folder_path):
+ logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
+ return
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+ vocab_dict = target_dict.indices
+
+ # fairseq has the <pad> and <s> switched
+ vocab_dict["<pad>"] = 42
+ vocab_dict["<s>"] = 43
+ with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
+ json.dump(vocab_dict, vocab_handle)
+ tokenizer = Wav2Vec2PhonemeCTCTokenizer(
+ vocab_path,
+ unk_token=target_dict.unk_word,
+ pad_token=target_dict.pad_word,
+ bos_token=target_dict.bos_word,
+ eos_token=target_dict.eos_word,
+ word_delimiter_token="|",
+ do_lower_case=False,
+ )
+ return_attention_mask = True if config.feat_extract_norm == "layer" else False
+ feature_extractor = Wav2Vec2FeatureExtractor(
+ feature_size=1,
+ sampling_rate=16000,
+ padding_value=0,
+ do_normalize=True,
+ return_attention_mask=return_attention_mask,
+ )
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ hf_unispeech = UniSpeechForCTC(config)
+ else:
+ hf_unispeech = UniSpeechForPreTraining(config)
+
+ if is_finetuned:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+ [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1]), "w2v_path": checkpoint_path}
+ )
+ else:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
+
+ model = model[0].eval()
+
+ recursively_load_weights(model, hf_unispeech, is_finetuned)
+
+ hf_unispeech.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ parser.add_argument(
+ "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
+ )
+ args = parser.parse_args()
+ convert_unispeech_checkpoint(
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
+ )
diff --git a/docs/transformers/src/transformers/models/unispeech/modeling_unispeech.py b/docs/transformers/src/transformers/models/unispeech/modeling_unispeech.py
new file mode 100644
index 0000000000000000000000000000000000000000..74608797ab68238824760df0953eb09dbf52138f
--- /dev/null
+++ b/docs/transformers/src/transformers/models/unispeech/modeling_unispeech.py
@@ -0,0 +1,1871 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from
src/transformers/models/unispeech/modular_unispeech.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_unispeech.py file directly. One of our CI checks enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ ModelOutput,
+ SequenceClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_unispeech import UniSpeechConfig
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
+
+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechConfig"
+
+
+@dataclass
+class UniSpeechForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+ paper](https://arxiv.org/pdf/2006.11477.pdf).
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
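+ codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
+ The perplexity of the codevector distribution, used to measure the diversity of the codebook.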
+ """ + + loss: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class UniSpeechSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class UniSpeechPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class UniSpeechNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = 
self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [UniSpeechGroupNormConvLayer(config, layer_id=0)] + [ + UniSpeechNoLayerNormConvLayer(config, layer_id=i + 1) + for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + UniSpeechLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + conv_layer.__call__, + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class UniSpeechFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class UniSpeechAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[UniSpeechConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + 
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class UniSpeechFlashAttention2(UniSpeechAttention):
+    """
+    UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # UniSpeechFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("UniSpeechFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need
+        # to cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class UniSpeechSdpaAttention(UniSpeechAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "UniSpeechModel is using UniSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz)
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+        # but we are fine here as `_shape` does call `.contiguous()`.
Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +class UniSpeechFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +UNISPEECH_ATTENTION_CLASSES = { + "eager": UniSpeechAttention, + "sdpa": UniSpeechSdpaAttention, + "flash_attention_2": UniSpeechFlashAttention2, +} + + +class UniSpeechEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = UniSpeechFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class UniSpeechEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + 
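+        # layers may be skipped stochastically at train time (LayerDrop, applied in
+        # the forward pass below); gradient checkpointing is toggled from outside via
+        # `PreTrainedModel.gradient_checkpointing_enable`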
self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states: torch.tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class UniSpeechAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
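+        The block is a bottleneck: LayerNorm, a down-projection to `config.adapter_attn_dim`,
+        a ReLU, and an up-projection back to `config.hidden_size`; the caller adds the
+        result to the residual stream.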
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +class UniSpeechEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = UniSpeechFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = UniSpeechAttnAdapterLayer(config) + else: + self.adapter_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class UniSpeechEncoderStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens are not attended to + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + 
+                attention_mask = attention_mask.expand(
+                    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+                )
+
+        position_embeddings = self.pos_conv_embed(hidden_states)
+        hidden_states = hidden_states + position_embeddings
+        hidden_states = self.dropout(hidden_states)
+
+        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+            if not skip_the_layer or synced_gpus:
+                # under fsdp or deepspeed zero3 all gpus must run in sync
+                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = layer(
+                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+                    )
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class UniSpeechGumbelVectorQuantizer(nn.Module):
+    """
+    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
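+
+    During training, hard one-hot code assignments are sampled with
+    `nn.functional.gumbel_softmax(..., hard=True)`, which keeps the forward pass
+    discrete while letting gradients flow through the soft distribution
+    (a straight-through estimator).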
+ """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + if config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {config.codevector_dim} must be divisible " + f"by `config.num_codevector_groups` {self.num_groups} for concatenation" + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + @staticmethod + def _compute_perplexity(probs): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class UniSpeechPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = UniSpeechConfig + base_model_prefix = "unispeech" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, UniSpeechGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, UniSpeechPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, UniSpeechFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. 
+            the first element is the batch size and the second element is the length of the axis to span.
+        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+            independently generated mask spans of length `mask_length` is computed by
+            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+            actual percentage will be smaller.
+        mask_length: size of the mask
+        min_masks: minimum number of masked spans
+        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+            each batch dimension.
+    """
+    batch_size, sequence_length = shape
+
+    if mask_length < 1:
+        raise ValueError("`mask_length` has to be bigger than 0.")
+
+    if mask_length > sequence_length:
+        raise ValueError(
+            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}"
+        )
+
+    # epsilon is used for probabilistic rounding
+    epsilon = np.random.rand(1).item()
+
+    def compute_num_masked_span(input_length):
+        """Given input length, compute how many spans should be masked"""
+        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+        num_masked_span = max(num_masked_span, min_masks)
+
+        # make sure num masked span <= sequence_length
+        if num_masked_span * mask_length > sequence_length:
+            num_masked_span = sequence_length // mask_length
+
+        # make sure num_masked span is also <= input_length - (mask_length - 1)
+        if input_length - (mask_length - 1) < num_masked_span:
+            num_masked_span = max(input_length - (mask_length - 1), 0)
+
+        return num_masked_span
+
+    # compute number of masked spans in batch
+    input_lengths = (
+        attention_mask.detach().sum(-1).tolist()
+        if attention_mask is not None
+        else [sequence_length for _ in range(batch_size)]
+    )
+
+    # SpecAugment mask to fill
+    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+    spec_aug_mask_idxs = []
+
+    max_num_masked_span = compute_num_masked_span(sequence_length)
+
+    if max_num_masked_span == 0:
+        return spec_aug_mask
+
+    for input_length in input_lengths:
+        # compute num of masked spans for this input
+        num_masked_span = compute_num_masked_span(input_length)
+
+        # get random indices to mask
+        spec_aug_mask_idx = np.random.choice(
+            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+        )
+
+        # pick first sampled index that will serve as a dummy index to pad vector
+        # to ensure same dimension for all batches due to probabilistic rounding
+        # Picking first sample just pads those vectors twice.
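+        # e.g. with max_num_masked_span == 3 but only num_masked_span == 2 sampled
+        # starts [4, 11], the padded vector becomes [4, 11, 4]; start index 4 is then
+        # masked twice, which is harmless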
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+
+UNISPEECH_START_DOCSTRING = r"""
+    UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled
+    Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
+    Michael Zeng, Xuedong Huang.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`UniSpeechConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+UNISPEECH_INPUTS_DOCSTRING = r"""
+    Args:
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+            1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+ + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.", + UNISPEECH_START_DOCSTRING, +) +class UniSpeechModel(UniSpeechPreTrainedModel): + def __init__(self, config: UniSpeechConfig): + super().__init__(config) + self.config = config + self.feature_extractor = UniSpeechFeatureEncoder(config) + self.feature_projection = UniSpeechFeatureProjection(config) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = UniSpeechEncoderStableLayerNorm(config) + else: + self.encoder = UniSpeechEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). 
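+        Masked time steps are overwritten with the learned `masked_spec_embed` vector,
+        while masked feature channels are zeroed out.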
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=UniSpeechBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return UniSpeechBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + 
+@add_start_docstrings(
+    """UniSpeech Model with a vector-quantization module and CTC loss for pre-training.""", UNISPEECH_START_DOCSTRING
+)
+class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
+    def __init__(self, config: UniSpeechConfig):
+        super().__init__(config)
+        self.unispeech = UniSpeechModel(config)
+        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+        self.quantizer = UniSpeechGumbelVectorQuantizer(config)
+        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+        self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)
+
+        self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_gumbel_temperature(self, temperature: int):
+        """
+        Set the Gumbel softmax temperature to a given value. Only necessary for training
+        """
+        self.quantizer.temperature = temperature
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.unispeech.feature_extractor._freeze_parameters()
+
+    @staticmethod
+    def compute_contrastive_logits(
+        target_features: torch.FloatTensor,
+        negative_features: torch.FloatTensor,
+        predicted_features: torch.FloatTensor,
+        temperature: int = 1,
+    ):
+        """
+        Compute logits for contrastive loss using cosine similarity as the distance measure between
+        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+        """
+        target_features = torch.cat([target_features, negative_features], dim=0)
+
+        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
+        logits = logits.type_as(target_features)
+
+        # apply temperature
+        logits = logits / temperature
+        return logits
+
+    @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=UniSpeechForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, UniSpeechForPreTrainingOutput]:
+        r"""
+        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+            masked extracted features in *config.proj_codevector_dim* space.
+        sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+            Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+            Required input for pre-training.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining
+
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
+        >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
+        >>> # TODO: Add full pretraining example
+        ```"""
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.unispeech(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        transformer_features = outputs[0]
+
+        # quantize all (unmasked) extracted features and project to final vq dim
+        extract_features = self.dropout_features(outputs[1])
+        quantized_features, codevector_perplexity = self.quantizer(extract_features)
+
+        # project quantized features twice
+        quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
+        quantized_features = self.project_hid(quantized_features)
+
+        prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
+            self.config.replace_prob
+        )
+        prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
+        sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
+        sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
+        sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
+        logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
+            quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
+        )
+
+        # project to ctc units
+        logits = self.dropout(logits)
+        logits = self.ctc_proj(logits)
+
+        # TODO(PVP) - add negative sampling & loss computation
+        loss = None
+        if not return_dict:
+            if loss is not None:
+                return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+        return UniSpeechForPreTrainingOutput(
+            loss=loss,
+            projected_states=transformer_features,
+            projected_quantized_states=quantized_features,
+            codevector_perplexity=codevector_perplexity,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
+_CTC_EXPECTED_LOSS = 17.17
+
+
+@add_start_docstrings(
+    """UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+    UNISPEECH_START_DOCSTRING,
+    """
+        target_lang (`str`, *optional*):
+            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+            adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng'
+            by default.
+    """,
+)
+class UniSpeechForCTC(UniSpeechPreTrainedModel):
+    def __init__(self, config, target_lang: Optional[str] = None):
+        super().__init__(config)
+
+        self.unispeech = UniSpeechModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        self.target_lang = target_lang
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
" + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang, force_load=True) + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.unispeech( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. + """, + UNISPEECH_START_DOCSTRING, +) +class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)" + ) + self.unispeech = UniSpeechModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. 
+ """ + self.unispeech.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.unispeech( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "UniSpeechForCTC", + "UniSpeechForPreTraining", + "UniSpeechForSequenceClassification", + "UniSpeechModel", + "UniSpeechPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/unispeech/modular_unispeech.py b/docs/transformers/src/transformers/models/unispeech/modular_unispeech.py new file mode 100644 index 0000000000000000000000000000000000000000..1096bc559b44ddd603784f9f03e1cbef174ded27 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech/modular_unispeech.py @@ -0,0 +1,563 @@ +import math +import warnings +from dataclasses import dataclass +from typing 
import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...modeling_outputs import CausalLMOutput, ModelOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from ..wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2Encoder,
+    Wav2Vec2EncoderStableLayerNorm,
+    Wav2Vec2FeatureEncoder,
+    Wav2Vec2FeatureProjection,
+    Wav2Vec2ForCTC,
+    Wav2Vec2ForSequenceClassification,
+    Wav2Vec2GumbelVectorQuantizer,
+    Wav2Vec2Model,
+    Wav2Vec2PositionalConvEmbedding,
+)
+from .configuration_unispeech import UniSpeechConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
+_CTC_EXPECTED_LOSS = 17.17
+
+
+@dataclass
+class UniSpeechForPreTrainingOutput(ModelOutput):
+    """
+    Output type of [`UniSpeechForPreTraining`], with potential hidden states and attentions.
+
+    Args:
+        loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+            paper](https://arxiv.org/pdf/2006.11477.pdf).
+        projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+            projected quantized states.
+        projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+            target vectors for contrastive loss.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
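For orientation, this container behaves like every other `ModelOutput`: fields left as `None` are skipped entirely, and the remaining fields are reachable both by name and by position. A small illustrative sketch (the tensor shapes are invented, and the import path assumes the generated `modeling_unispeech.py` file):

```python
import torch

from transformers.models.unispeech.modeling_unispeech import UniSpeechForPreTrainingOutput

out = UniSpeechForPreTrainingOutput(
    projected_states=torch.randn(1, 292, 256),
    projected_quantized_states=torch.randn(1, 292, 256),
    codevector_perplexity=torch.tensor(130.0),
)

# `loss` was left as None, so it is dropped and positional indexing starts
# at the first populated field
assert out[0] is out.projected_states
assert "loss" not in out
```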
+ """ + + loss: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding): + pass + + +class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection): + pass + + +class UniSpeechEncoder(Wav2Vec2Encoder): + pass + + +class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm): + pass + + +class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer): + @staticmethod + def _compute_perplexity(probs): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class UniSpeechPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = UniSpeechConfig + base_model_prefix = "unispeech" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, UniSpeechGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, UniSpeechPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, UniSpeechFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +UNISPEECH_START_DOCSTRING = r""" + UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled + Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, + Michael Zeng, Xuedong Huang. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`UniSpeechConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +UNISPEECH_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
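As the docstring above suggests, `input_values` is normally produced by a processor that pads a batch of raw waveforms and, depending on the checkpoint, emits the matching `attention_mask`. A hedged sketch, assuming the docstring checkpoint is reachable and ships a processor:

```python
import numpy as np
from transformers import AutoProcessor

# checkpoint name taken from _CHECKPOINT_FOR_DOC above
processor = AutoProcessor.from_pretrained("patrickvonplaten/unispeech-large-1500h-cv-timit")

# two raw 16 kHz waveforms of different lengths (silence, for illustration)
speech = [np.zeros(16000, dtype=np.float32), np.zeros(12000, dtype=np.float32)]
inputs = processor(speech, sampling_rate=16000, padding=True, return_tensors="pt")

print(inputs.input_values.shape)   # torch.Size([2, 16000]) - padded to the longest sample
print("attention_mask" in inputs)  # True only if the checkpoint sets return_attention_mask
```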
+""" + + +UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.", + UNISPEECH_START_DOCSTRING, +) +class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: UniSpeechConfig): + UniSpeechPreTrainedModel.__init__(config) + self.config = config + self.feature_extractor = UniSpeechFeatureEncoder(config) + self.feature_projection = UniSpeechFeatureProjection(config) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = UniSpeechEncoderStableLayerNorm(config) + else: + self.encoder = UniSpeechEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for UniSpeech") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for UniSpeech") + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=UniSpeechBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return UniSpeechBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeech Model with a vector-quantization module and ctc loss for pre-training.""", UNISPEECH_START_DOCSTRING +) +class UniSpeechForPreTraining(UniSpeechPreTrainedModel): + def __init__(self, config: UniSpeechConfig): + super().__init__(config) + self.unispeech = UniSpeechModel(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = UniSpeechGumbelVectorQuantizer(config) + self.project_q = 
nn.Linear(config.codevector_dim, config.proj_codevector_dim) + self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size) + + self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes) + self.dropout = nn.Dropout(config.final_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + self.quantizer.temperature = temperature + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech.feature_extractor._freeze_parameters() + + @staticmethod + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. + """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1) + logits = logits.type_as(target_features) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=UniSpeechForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechForPreTrainingOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*): + Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss. + Required input for pre-training. 
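`compute_contrastive_logits` stacks the positive targets on top of the negatives and scores everything against the predictions with cosine similarity divided by a temperature. A toy reproduction of that pattern (all shapes and the `0.1` temperature are assumptions for illustration):

```python
import torch

seq_len, hidden = 10, 256
predicted = torch.randn(1, seq_len, hidden)       # context network outputs
positives = torch.randn(1, seq_len, hidden)       # quantized targets
negatives = torch.randn(100, 1, seq_len, hidden)  # 100 distractors per time step

targets = torch.cat([positives[None, :], negatives], dim=0)  # (101, 1, seq, hidden)
logits = torch.cosine_similarity(predicted[None, :].float(), targets.float(), dim=-1)
logits = logits / 0.1  # temperature kappa

print(logits.shape)  # (101, 1, seq_len): row 0 holds the positive scores
```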
+ + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv") + >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv") + >>> # TODO: Add full pretraining example + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.unispeech( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + transformer_features = outputs[0] + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + quantized_features, codevector_perplexity = self.quantizer(extract_features) + + # project quantized features twice + quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype)) + quantized_features = self.project_hid(quantized_features) + + prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_( + self.config.replace_prob + ) + prob_replace_matrix = prob_replace_matrix.transpose(0, 1) + sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device) + sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1) + sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1) + logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + ( + quantized_features.masked_fill(~sampled_replace_matrix, 0.0) + ) + + # project to ctc units + logits = self.dropout(logits) + logits = self.ctc_proj(logits) + + # TODO(PVP) - add negative sampling & loss computation + loss = None + if not return_dict: + if loss is not None: + return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return UniSpeechForPreTrainingOutput( + loss=loss, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + UNISPEECH_START_DOCSTRING, + """ + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' + by default. + """, +) +class UniSpeechForCTC(Wav2Vec2ForCTC): + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
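The pre-training forward above mixes transformer features and quantized features using a per-time-step Bernoulli mask drawn with `config.replace_prob`. The core of that mixing, isolated on toy tensors:

```python
import torch

batch, seq, hidden = 2, 6, 4
transformer_features = torch.randn(batch, seq, hidden)
quantized_features = torch.randn(batch, seq, hidden)
replace_prob = 0.5  # mirrors config.replace_prob

# draw one Bernoulli per time step, then broadcast over the hidden dimension
sampled = torch.bernoulli(torch.full((batch, seq), replace_prob)).bool().unsqueeze(-1)

# where sampled is True the transformer feature is zeroed and the quantized
# feature kept, and vice versa - a per-position hard mix of the two streams
logits = transformer_features.masked_fill(sampled, 0.0) + quantized_features.masked_fill(~sampled, 0.0)
print(logits.shape)  # (2, 6, 4)
```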
+ """, + UNISPEECH_START_DOCSTRING, +) +class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification): + @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +__all__ = [ + "UniSpeechForCTC", + "UniSpeechForPreTraining", + "UniSpeechForSequenceClassification", + "UniSpeechModel", + "UniSpeechPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/unispeech_sat/__init__.py b/docs/transformers/src/transformers/models/unispeech_sat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3144ed17d14cd7db2b93b05d21d0a0e3ffe3d0 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech_sat/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_unispeech_sat import * + from .modeling_unispeech_sat import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py b/docs/transformers/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py new file mode 100644 index 0000000000000000000000000000000000000000..a33403306e8aa72376bed4edc915001490182047 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""UniSpeechSat model configuration""" + +import functools +import operator + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class UniSpeechSatConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`UniSpeechSatModel`]. It is used to instantiate an + UniSpeechSat model according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the UniSpeechSat + [microsoft/unispeech-sat-base-100h-libri-ft](https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32): + Vocabulary size of the UniSpeechSat model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`UniSpeechSatModel`]. Vocabulary size of the model. Defines the + different tokens that can be represented by the *inputs_ids* passed to the forward method of + [`UniSpeechSatModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for activations inside the fully connected layer. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + feat_proj_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for output of the feature encoder. + feat_quantizer_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the output of the feature encoder that's used by the quantizer. + final_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the final projection layer of [`UniSpeechSatForCTC`]. + layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more + details. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + feat_extract_norm (`str`, *optional*, defaults to `"group"`): + The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group + normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D + convolutional layers. + feat_extract_activation (`str, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the 1D convolutional layers of the feature + extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. 
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`): + A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the + feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers. + conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`): + A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length + of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*. + conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The + length of *conv_kernel* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether the 1D convolutional layers have a bias. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer. + num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer. + do_stable_layer_norm (`bool`, *optional*, defaults to `False`): + Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is + True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is + False` corresponds to applying layer norm after the attention layer. + apply_spec_augment (`bool`, *optional*, defaults to `True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see + [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If + reasoning from the propability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2): + The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, + irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks'' + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over + the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. 
This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0): + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespectively of `mask_feature_prob`. Only relevant if + ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks'' + num_codevectors_per_group (`int`, *optional*, defaults to 320): + Number of entries in each quantization codebook (group). + num_codevector_groups (`int`, *optional*, defaults to 2): + Number of codevector groups for product codevector quantization. + contrastive_logits_temperature (`float`, *optional*, defaults to 0.1): + The temperature *kappa* in the contrastive loss. + num_negatives (`int`, *optional*, defaults to 100): + Number of negative samples for the contrastive loss. + codevector_dim (`int`, *optional*, defaults to 256): + Dimensionality of the quantized feature vectors. + proj_codevector_dim (`int`, *optional*, defaults to 256): + Dimensionality of the final projection of both the quantized and the transformer features. + diversity_loss_weight (`int`, *optional*, defaults to 0.1): + The weight of the codebook diversity loss component. + ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`): + Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an + instance of [`UniSpeechSatForCTC`]. + ctc_zero_infinity (`bool`, *optional*, defaults to `False`): + Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly + occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance + of [`UniSpeechSatForCTC`]. + use_weighted_layer_sum (`bool`, *optional*, defaults to `False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of [`UniSpeechSatForSequenceClassification`]. + classifier_proj_size (`int`, *optional*, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. + tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`): + A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN* + module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers. + tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the + *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*. + tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`): + A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the + *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*. + xvector_output_dim (`int`, *optional*, defaults to 512): + Dimensionality of the *XVector* embedding vectors. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. 
+ num_clusters (`int`, *optional*, defaults to 504): + Number of clusters for weak labeling. Only relevant when using an instance of + [`UniSpeechSatForPreTraining`]. + + Example: + + ```python + >>> from transformers import UniSpeechSatModel, UniSpeechSatConfig + + >>> # Initializing a UniSpeechSat microsoft/unispeech-sat-base-100h-libri-ft style configuration + >>> configuration = UniSpeechSatConfig() + + >>> # Initializing a model from the microsoft/unispeech-sat-base-100h-libri-ft style configuration + >>> model = UniSpeechSatModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "unispeech-sat" + + def __init__( + self, + vocab_size=32, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.0, + feat_quantizer_dropout=0.0, + final_dropout=0.1, + layerdrop=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + feat_extract_norm="group", + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + do_stable_layer_norm=False, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + num_codevectors_per_group=320, + num_codevector_groups=2, + contrastive_logits_temperature=0.1, + num_negatives=100, + codevector_dim=256, + proj_codevector_dim=256, + diversity_loss_weight=0.1, + ctc_loss_reduction="mean", + ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, + tdnn_dim=(512, 512, 512, 512, 1500), + tdnn_kernel=(5, 3, 3, 1, 1), + tdnn_dilation=(1, 2, 3, 1, 1), + xvector_output_dim=512, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + num_clusters=504, + **kwargs, + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.num_clusters = num_clusters + self.do_stable_layer_norm = do_stable_layer_norm + self.use_weighted_layer_sum = use_weighted_layer_sum + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is 
incorrect. It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # parameters for pretraining with codevector quantized representations + self.num_codevectors_per_group = num_codevectors_per_group + self.num_codevector_groups = num_codevector_groups + self.contrastive_logits_temperature = contrastive_logits_temperature + self.feat_quantizer_dropout = feat_quantizer_dropout + self.num_negatives = num_negatives + self.codevector_dim = codevector_dim + self.proj_codevector_dim = proj_codevector_dim + self.diversity_loss_weight = diversity_loss_weight + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. + self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +__all__ = ["UniSpeechSatConfig"] diff --git a/docs/transformers/src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0f1256e0ca3ee6f82839ba3c0c3f94d1b0c70256 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
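The `inputs_to_logits_ratio` property defined above is just the product of the convolutional strides, i.e. the approximate number of raw samples per output frame (approximate because kernel edge effects are ignored):

```python
import functools
import operator

conv_stride = (5, 2, 2, 2, 2, 2, 2)  # default from the config above
ratio = functools.reduce(operator.mul, conv_stride, 1)
print(ratio)  # 320: one output frame per ~320 input samples, i.e. 20 ms at 16 kHz
```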
+"""Convert Hubert checkpoint.""" + +import argparse + +import torch + +from transformers import ( + UniSpeechSatConfig, + UniSpeechSatForAudioFrameClassification, + UniSpeechSatForSequenceClassification, + UniSpeechSatForXVector, + Wav2Vec2FeatureExtractor, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def convert_classification(base_model_name, hf_config, downstream_dict): + model = UniSpeechSatForSequenceClassification.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["projector.weight"] + model.projector.bias.data = downstream_dict["projector.bias"] + model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + return model + + +def convert_diarization(base_model_name, hf_config, downstream_dict): + model = UniSpeechSatForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config) + model.classifier.weight.data = downstream_dict["model.linear.weight"] + model.classifier.bias.data = downstream_dict["model.linear.bias"] + return model + + +def convert_xvector(base_model_name, hf_config, downstream_dict): + model = UniSpeechSatForXVector.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["connector.weight"] + model.projector.bias.data = downstream_dict["connector.bias"] + for i, kernel_size in enumerate(hf_config.tdnn_kernel): + model.tdnn[i].kernel.weight.data = downstream_dict[ + f"model.framelevel_feature_extractor.module.{i}.kernel.weight" + ] + model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"] + + model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"] + model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"] + model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"] + model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"] + model.objective.weight.data = downstream_dict["objective.W"] + return model + + +@torch.no_grad() +def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path): + """ + Copy/paste/tweak model's weights to transformers design. 
+ """ + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + downstream_dict = checkpoint["Downstream"] + + hf_config = UniSpeechSatConfig.from_pretrained(config_path) + hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + base_model_name, return_attention_mask=True, do_normalize=False + ) + + arch = hf_config.architectures[0] + if arch.endswith("ForSequenceClassification"): + hf_model = convert_classification(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForAudioFrameClassification"): + hf_model = convert_diarization(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForXVector"): + hf_model = convert_xvector(base_model_name, hf_config, downstream_dict) + else: + raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}") + + if hf_config.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + + hf_feature_extractor.save_pretrained(model_dump_path) + hf_model.save_pretrained(model_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model." + ) + parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.") + parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.") + args = parser.parse_args() + convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path) diff --git a/docs/transformers/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..4a70d41dd2823460686c482518ca10f7f1cab2e7 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert UniSpeechSat checkpoint.""" + +import argparse + +import fairseq +import torch + +from transformers import UniSpeechSatConfig, UniSpeechSatForCTC, UniSpeechSatForPreTraining, logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "encoder.layer_norm_for_extract": "layer_norm_for_extract", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "lm_head", + "label_embs_concat": "label_embeddings_concat", + "mask_emb": "masked_spec_embed", + "spk_proj": "speaker_proj", +} +TOP_LEVEL_KEYS = [ + "lm_head", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", + "label_embeddings_concat", + "speaker_proj", + "layer_norm_for_extract", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + if hf_shape != value.shape: + raise ValueError( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights(fairseq_model, hf_model): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.unispeech_sat.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + mapped_key = "unispeech_sat." 
+ mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + if "layer_norm_for_extract" in name and (".".join(name.split(".")[:-1]) != key): + # special case since naming is very similar + continue + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + # TODO: don't match quantizer.weight_proj + weight_type = "weight" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor[layer_id].layer_norm.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +@torch.no_grad() +def convert_unispeech_sat_checkpoint( + checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True +): + """ + Copy/paste/tweak model's weights to transformers design. 
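The renaming loop in `recursively_load_weights` boils down to: find the `MAPPING` key contained in the fairseq name, recover the layer index from the text before the match, and substitute it for the `*` placeholder. A toy reduction of just that logic (the mapping excerpt is assumed for illustration):

```python
# toy excerpt of the fairseq-to-HF renaming performed above
MAPPING = {
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
}


def rename(fairseq_key: str) -> str:
    for key, mapped in MAPPING.items():
        if key in fairseq_key:
            # the layer index sits just before the matched suffix
            layer_index = fairseq_key.split(key)[0].split(".")[-2]
            return "unispeech_sat." + mapped.replace("*", layer_index)
    return fairseq_key


print(rename("encoder.layers.3.self_attn.k_proj.weight"))
# unispeech_sat.encoder.layers.3.attention.k_proj
```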
+ """ + if config_path is not None: + config = UniSpeechSatConfig.from_pretrained(config_path) + else: + config = UniSpeechSatConfig() + + dict_path = "" + + if is_finetuned: + hf_wav2vec = UniSpeechSatForCTC(config) + else: + hf_wav2vec = UniSpeechSatForPreTraining(config) + + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])} + ) + model = model[0].eval() + + recursively_load_weights(model, hf_wav2vec) + + hf_wav2vec.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not" + ) + args = parser.parse_args() + convert_unispeech_sat_checkpoint( + args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned + ) diff --git a/docs/transformers/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/docs/transformers/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f8c4c3466ebcc37568d6c350d6a441568bfe26 --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -0,0 +1,2189 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/unispeech_sat/modular_unispeech_sat.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_unispeech_sat.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_outputs import (
+    BaseModelOutput,
+    CausalLMOutput,
+    ModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+    Wav2Vec2BaseModelOutput,
+    XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_peft_available,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_unispeech_sat import UniSpeechSatConfig
+
+
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
+
+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechSatConfig"
+
+
+@dataclass
+class UniSpeechSatForPreTrainingOutput(ModelOutput):
+    """
+    Output type of [`UniSpeechSatForPreTraining`], with potential hidden states and attentions.
+
+    Args:
+        loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+            paper](https://arxiv.org/pdf/2006.11477.pdf).
+        projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+            projected quantized states.
+        projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+            target vectors for contrastive loss.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class UniSpeechSatSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class UniSpeechSatPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = UniSpeechSatSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class UniSpeechSatNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechSatLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = 
hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechSatGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class UniSpeechSatFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [UniSpeechSatGroupNormConvLayer(config, layer_id=0)] + [ + UniSpeechSatNoLayerNormConvLayer(config, layer_id=i + 1) + for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + UniSpeechSatLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + conv_layer.__call__, + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class UniSpeechSatFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class UniSpeechSatAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[UniSpeechSatConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + 
if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class UniSpeechSatFlashAttention2(UniSpeechSatAttention): + """ + UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # UniSpeechSatFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("UniSpeechSatFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
+        # the input hidden states therefore get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended to not cast the LayerNorms
+        # to fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this is likely because you"
+                f" have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "UniSpeechSatModel is using UniSpeechSatSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +class UniSpeechSatFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +UNISPEECH_SAT_ATTENTION_CLASSES = { + "eager": UniSpeechSatAttention, + "sdpa": UniSpeechSatSdpaAttention, + "flash_attention_2": UniSpeechSatFlashAttention2, +} + + +class UniSpeechSatEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = UniSpeechSatFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class UniSpeechSatEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = UniSpeechSatPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in 
range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states: torch.tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class UniSpeechSatAttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. 
+ """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = UniSpeechSatFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = UniSpeechSatAttnAdapterLayer(config) + else: + self.adapter_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class UniSpeechSatEncoderStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = UniSpeechSatPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [UniSpeechSatEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens are not attended to + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * 
torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class UniSpeechSatGumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH + GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. 
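+
+    Example (a minimal sketch of the straight-through Gumbel-softmax step this module relies on; the tensor sizes
+    are illustrative and not taken from a real checkpoint):
+
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+
+    >>> logits = torch.randn(8, 320)  # (batch * frames * groups, codevectors per group)
+    >>> probs = nn.functional.gumbel_softmax(logits, tau=2.0, hard=True)  # one-hot forward, soft gradient
+    >>> torch.allclose(probs.sum(-1), torch.ones(8))  # each row selects exactly one codevector
+    True
+    ```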
+ """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + if config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {config.codevector_dim} must be divisible " + f"by `config.num_codevector_groups` {self.num_groups} for concatenation" + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.hidden_size, self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + @staticmethod + def _compute_perplexity(probs, mask=None): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class UniSpeechSatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = UniSpeechSatConfig + base_model_prefix = "unispeech_sat" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, UniSpeechSatGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, UniSpeechSatPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, UniSpeechSatFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. 
This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
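+            # e.g. (hypothetical numbers) if max_num_masked_span == 3 but only num_masked_span == 1 span was
+            # sampled for a short input, the start indices [7] are padded to [7, 7, 7]; masking the same span
+            # again is a no-op on the boolean mask, so the padding only equalizes the per-example index count.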
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+
+UNISPEECH_SAT_START_DOCSTRING = r"""
+    UniSpeechSat was proposed in [UniSpeech-SAT: Universal Speech Representation Learning with Speaker Aware
+    Pre-Training](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen,
+    Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`UniSpeechSatConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+            1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+ + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.", + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatModel(UniSpeechSatPreTrainedModel): + def __init__(self, config: UniSpeechSatConfig): + super().__init__(config) + self.config = config + self.feature_extractor = UniSpeechSatFeatureEncoder(config) + self.feature_projection = UniSpeechSatFeatureProjection(config) + + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = UniSpeechSatEncoderStableLayerNorm(config) + else: + self.encoder = UniSpeechSatEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). 
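+
+        A rough illustration of the time-axis masking applied here (hypothetical sizes; in the real call,
+        `mask_prob` and `mask_length` come from `config.mask_time_prob` and `config.mask_time_length`):
+
+        ```python
+        >>> import torch
+        >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
+
+        >>> mask = torch.tensor(_compute_mask_indices((2, 50), mask_prob=0.2, mask_length=5), dtype=torch.bool)
+        >>> mask.shape  # one boolean entry per feature frame
+        torch.Size([2, 50])
+        ```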
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=UniSpeechSatBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechSatBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return UniSpeechSatBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + 
attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""", + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): + def __init__(self, config: UniSpeechSatConfig): + super().__init__(config) + self.unispeech_sat = UniSpeechSatModel(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = UniSpeechSatGumbelVectorQuantizer(config) + self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) + + self.dropout = nn.Dropout(config.final_dropout) + + self.speaker_proj = nn.Linear(config.hidden_size, config.codevector_dim) + self.label_embeddings_concat = nn.Parameter(torch.FloatTensor(config.num_clusters, config.codevector_dim)) + self.label_embeddings_concat.data.zero_() + + self.layer_norm_for_extract = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if self.config.do_stable_layer_norm: + self.layer_norm_for_extract.requires_grad = False + + # Initialize weights and apply final processing + self.post_init() + + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + self.quantizer.temperature = temperature + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + @staticmethod + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. 
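+
+        Example (an illustrative sketch with random toy tensors; shapes follow the wav2vec 2.0-style setup of one
+        positive plus `K` negatives per frame and are not taken from a real checkpoint):
+
+        ```python
+        >>> import torch
+        >>> from transformers import UniSpeechSatForPreTraining
+
+        >>> target = torch.randn(1, 2, 5, 16)  # (1, batch, frames, dim): the positives
+        >>> negatives = torch.randn(10, 2, 5, 16)  # (K, batch, frames, dim)
+        >>> predicted = torch.randn(2, 5, 16)  # (batch, frames, dim)
+        >>> logits = UniSpeechSatForPreTraining.compute_contrastive_logits(target, negatives, predicted, 0.1)
+        >>> logits.shape  # (1 + K, batch, frames): row 0 scores the positive
+        torch.Size([11, 2, 5])
+        ```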
+ """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1) + logits = logits.type_as(target_features) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=UniSpeechSatForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechSatForPreTrainingOutput]: + r""" + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining + >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base") + >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base") + >>> # TODO: Add full pretraining example + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + transformer_features = outputs[0] + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + + # TODO(PVP) - add pretraining logic and add to tests + logits = extract_features + loss = quantized_features = codevector_perplexity = None + + if not return_dict: + if loss is not None: + return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return UniSpeechSatForPreTrainingOutput( + loss=loss, + logits=logits, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +_HIDDEN_STATES_START_POSITION = 2 + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 39.88 + + +@add_start_docstrings( + """UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + UNISPEECH_SAT_START_DOCSTRING, + """ + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`UniSpeechSatForCTC`] with adapters. Uses + 'eng' by default. + """, +) +class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel): + def __init__(self, config, target_lang: Optional[str] = None): + super().__init__(config) + + self.unispeech_sat = UniSpeechSatModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + self.target_lang = target_lang + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. 
Please " + "instantiate the model as follows: `UniSpeechSatForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when + passing `target_lang=...` to `from_pretrained(...)`. + + This method is **not** supposed to be called by the user and is prone to be changed in the future. + """ + + # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to + # correctly load adapter layers for UniSpeechSat so that we do not have to introduce a new API to + # [`PreTrainedModel`]. While slightly hacky, UniSpeechSat never has to tie input and output embeddings, so that it is + # ok to repurpose this function here. + target_lang = self.target_lang + + if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: + raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") + elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: + logger.info("By default `target_lang` is set to 'eng'.") + elif target_lang is not None: + self.load_adapter(target_lang, force_load=True) + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech_sat.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. 
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. + """, + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Sequence classification does not support the use of UniSpeechSat adapters (config.add_adapter=True)" + ) + self.unispeech_sat = UniSpeechSatModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. 
+ """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech_sat.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Frame class docstring +_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd" +_FRAME_EXPECTED_OUTPUT = [0, 0] + + +@add_start_docstrings( + """ + UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization. 
+ """, + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Audio frame classification does not support the use of UniSpeechSat adapters (config.add_adapter=True)" + ) + self.unispeech_sat = UniSpeechSatModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.num_labels = config.num_labels + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech_sat.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_FRAME_CLASS_CHECKPOINT, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_FRAME_EXPECTED_OUTPUT, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + super(AMSoftmaxLoss, self).__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +class TDNNLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = config.tdnn_dim[layer_id] + self.kernel_size = config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if is_peft_available(): + from peft.tuners.lora import LoraLayer + + if is_peft_available(): + if isinstance(self.kernel, LoraLayer): + warnings.warn( + "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. " + "You should exclude TDNNLayer from LoRA's target modules.", + ) + + # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up + hidden_states = hidden_states.transpose(1, 2) + weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2) + hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Speaker Verification docstring +_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv" +_XVECTOR_EXPECTED_OUTPUT = 0.97 + + +@add_start_docstrings( + """ + UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification. 
+ """, + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.unispeech_sat = UniSpeechSatModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech_sat.parameters(): + param.requires_grad = False + + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_XVECTOR_CHECKPOINT, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_XVECTOR_EXPECTED_OUTPUT, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, XVectorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "UniSpeechSatForAudioFrameClassification", + "UniSpeechSatForCTC", + "UniSpeechSatForPreTraining", + "UniSpeechSatForSequenceClassification", + "UniSpeechSatForXVector", + "UniSpeechSatModel", + "UniSpeechSatPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/docs/transformers/src/transformers/models/unispeech_sat/modular_unispeech_sat.py new file mode 100644 index 0000000000000000000000000000000000000000..44e566068ef5e4d8da7438971c36d8482a9fc4bb --- /dev/null +++ b/docs/transformers/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -0,0 +1,610 @@ +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...modeling_outputs import ( + CausalLMOutput, + ModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Encoder, + Wav2Vec2EncoderStableLayerNorm, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeatureProjection, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + Wav2Vec2GumbelVectorQuantizer, + 
Wav2Vec2Model,
+    Wav2Vec2PositionalConvEmbedding,
+)
+from .configuration_unispeech_sat import UniSpeechSatConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechSatConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 39.88
+
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97
+
+
+@dataclass
+class UniSpeechSatForPreTrainingOutput(ModelOutput):
+    """
+    Output type of [`UniSpeechSatForPreTraining`], with potential hidden states and attentions.
+
+    Args:
+        loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the
+            [official paper](https://arxiv.org/pdf/2006.11477.pdf).
+        projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the
+            masked projected quantized states.
+        projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+            target vectors for contrastive loss.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
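+
+    (Editor's note: a short sketch of consuming this output. With `return_dict=True` the fields above are
+    attributes; with `return_dict=False` the same values come back as a plain tuple. `model` is assumed to be a
+    `UniSpeechSatForPreTraining` instance:)
+
+    ```python
+    >>> outputs = model(input_values)
+    >>> hidden = outputs.projected_states
+    >>> logits, hidden = model(input_values, return_dict=False)[:2]
+    ```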
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class UniSpeechSatPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding): + pass + + +class UniSpeechSatFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class UniSpeechSatFeatureProjection(Wav2Vec2FeatureProjection): + pass + + +class UniSpeechSatEncoder(Wav2Vec2Encoder): + pass + + +class UniSpeechSatEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm): + pass + + +class UniSpeechSatGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer): + def __init__(self, config): + super().__init__() + self.weight_proj = nn.Linear(config.hidden_size, self.num_groups * self.num_vars) + + @staticmethod + def _compute_perplexity(probs, mask=None): + marginal_probs = probs.mean(dim=0) + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class UniSpeechSatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = UniSpeechSatConfig + base_model_prefix = "unispeech_sat" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # gumbel softmax requires special init + if isinstance(module, UniSpeechSatGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, UniSpeechSatPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, UniSpeechSatFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +UNISPEECH_SAT_START_DOCSTRING = r""" + UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech + Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael + Auli. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`UniSpeechSatConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +UNISPEECH_SAT_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should + **not** be passed to avoid degraded performance when doing batched inference. For such models + `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these + models also yield slightly different results depending on whether `input_values` is padded or not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
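+
+    (Editor's note: a hedged sketch of the batching caveat above; the audio arrays are random placeholders, and
+    whether `attention_mask` shows up in the extractor output depends on the checkpoint's
+    `return_attention_mask` setting:)
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import AutoFeatureExtractor
+
+    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
+    >>> audio = [np.random.randn(16000), np.random.randn(24000)]
+    >>> inputs = feature_extractor(audio, sampling_rate=16000, padding=True, return_tensors="pt")
+    >>> sorted(inputs.keys())  # contains "attention_mask" only if the extractor returns one
+    ```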
+""" + +UniSpeechSatBaseModelOutput = Wav2Vec2BaseModelOutput + + +@add_start_docstrings( + "The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.", + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatModel(UniSpeechSatPreTrainedModel, Wav2Vec2Model): + def __init__(self, config: UniSpeechSatConfig): + UniSpeechSatPreTrainedModel.__init__(config) + self.config = config + self.feature_extractor = UniSpeechSatFeatureEncoder(config) + self.feature_projection = UniSpeechSatFeatureProjection(config) + + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = UniSpeechSatEncoderStableLayerNorm(config) + else: + self.encoder = UniSpeechSatEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_extractor(self): + raise AttributeError("Not needed for UniSpeechSat") + + def freeze_feature_encoder(self): + raise AttributeError("Not needed for UniSpeechSat") + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=UniSpeechSatBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechSatBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return UniSpeechSatBaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeechSat Model with a vector-quantization module and ctc loss for pre-training.""", + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel): + def __init__(self, config: UniSpeechSatConfig): + super().__init__(config) + self.unispeech_sat = UniSpeechSatModel(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = UniSpeechSatGumbelVectorQuantizer(config) + 
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        self.speaker_proj = nn.Linear(config.hidden_size, config.codevector_dim)
+        self.label_embeddings_concat = nn.Parameter(torch.FloatTensor(config.num_clusters, config.codevector_dim))
+        self.label_embeddings_concat.data.zero_()
+
+        self.layer_norm_for_extract = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        if self.config.do_stable_layer_norm:
+            self.layer_norm_for_extract.requires_grad = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_gumbel_temperature(self, temperature: int):
+        """
+        Set the Gumbel softmax temperature to a given value. Only necessary for training.
+        """
+        self.quantizer.temperature = temperature
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        self.unispeech_sat.feature_extractor._freeze_parameters()
+
+    @staticmethod
+    def compute_contrastive_logits(
+        target_features: torch.FloatTensor,
+        negative_features: torch.FloatTensor,
+        predicted_features: torch.FloatTensor,
+        temperature: int = 1,
+    ):
+        """
+        Compute logits for contrastive loss using cosine similarity as the distance measure between
+        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, a temperature can be
+        applied.
+ """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1) + logits = logits.type_as(target_features) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=UniSpeechSatForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, UniSpeechSatForPreTrainingOutput]: + r""" + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining + >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base") + >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base") + >>> # TODO: Add full pretraining example + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + transformer_features = outputs[0] + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + + # TODO(PVP) - add pretraining logic and add to tests + logits = extract_features + loss = quantized_features = codevector_perplexity = None + + if not return_dict: + if loss is not None: + return (loss, logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (logits, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return UniSpeechSatForPreTrainingOutput( + loss=loss, + logits=logits, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + UNISPEECH_SAT_START_DOCSTRING, + """ + target_lang (`str`, *optional*): + Language id of adapter weights. Adapter weights are stored in the format adapter..safetensors or + adapter..bin. Only relevant when using an instance of [`UniSpeechSatForCTC`] with adapters. Uses + 'eng' by default. + """, +) +class UniSpeechSatForCTC(Wav2Vec2ForCTC): + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + UniSpeechSat Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
+ """, + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForSequenceClassification(Wav2Vec2ForSequenceClassification): + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + UniSpeechSat Model with a frame classification head on top for tasks like Speaker Diarization. + """, + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification): + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_FRAME_CLASS_CHECKPOINT, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_FRAME_EXPECTED_OUTPUT, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +@add_start_docstrings( + """ + UniSpeechSat Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + UNISPEECH_SAT_START_DOCSTRING, +) +class UniSpeechSatForXVector(Wav2Vec2ForXVector): + pass + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_XVECTOR_CHECKPOINT, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_XVECTOR_EXPECTED_OUTPUT, + ) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +__all__ = [ + "UniSpeechSatForAudioFrameClassification", + "UniSpeechSatForCTC", + "UniSpeechSatForPreTraining", + "UniSpeechSatForSequenceClassification", + "UniSpeechSatForXVector", + "UniSpeechSatModel", + "UniSpeechSatPreTrainedModel", +] diff --git a/docs/transformers/src/transformers/models/univnet/__init__.py b/docs/transformers/src/transformers/models/univnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d02d0dbe574ebac34022e5e8727c900e16f5977 --- /dev/null +++ b/docs/transformers/src/transformers/models/univnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_univnet import * + from .feature_extraction_univnet import * + from .modeling_univnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/univnet/configuration_univnet.py b/docs/transformers/src/transformers/models/univnet/configuration_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3811ee3a261894bab37c059f0b8ac5a3948c34 --- /dev/null +++ b/docs/transformers/src/transformers/models/univnet/configuration_univnet.py @@ -0,0 +1,125 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""UnivNetModel model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class UnivNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`UnivNetModel`]. It is used to instantiate a + UnivNet vocoder model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the UnivNet + [dg845/univnet-dev](https://huggingface.co/dg845/univnet-dev) architecture, which corresponds to the 'c32' + architecture in [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/master/config/default_c32.yaml). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + model_in_channels (`int`, *optional*, defaults to 64): + The number of input channels for the UnivNet residual network. This should correspond to + `noise_sequence.shape[1]` and the value used in the [`UnivNetFeatureExtractor`] class. + model_hidden_channels (`int`, *optional*, defaults to 32): + The number of hidden channels of each residual block in the UnivNet residual network. + num_mel_bins (`int`, *optional*, defaults to 100): + The number of frequency bins in the conditioning log-mel spectrogram. This should correspond to the value + used in the [`UnivNetFeatureExtractor`] class. + resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 3, 3]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the UnivNet residual + network. The length of `resblock_kernel_sizes` defines the number of resnet blocks and should match that of + `resblock_stride_sizes` and `resblock_dilation_sizes`. + resblock_stride_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 4]`): + A tuple of integers defining the stride sizes of the 1D convolutional layers in the UnivNet residual + network. 
The length of `resblock_stride_sizes` should match that of `resblock_kernel_sizes` and + `resblock_dilation_sizes`. + resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 9, 27], [1, 3, 9, 27], [1, 3, 9, 27]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + UnivNet residual network. The length of `resblock_dilation_sizes` should match that of + `resblock_kernel_sizes` and `resblock_stride_sizes`. The length of each nested list in + `resblock_dilation_sizes` defines the number of convolutional layers per resnet block. + kernel_predictor_num_blocks (`int`, *optional*, defaults to 3): + The number of residual blocks in the kernel predictor network, which calculates the kernel and bias for + each location variable convolution layer in the UnivNet residual network. + kernel_predictor_hidden_channels (`int`, *optional*, defaults to 64): + The number of hidden channels for each residual block in the kernel predictor network. + kernel_predictor_conv_size (`int`, *optional*, defaults to 3): + The kernel size of each 1D convolutional layer in the kernel predictor network. + kernel_predictor_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for each residual block in the kernel predictor network. + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + leaky_relu_slope (`float`, *optional*, defaults to 0.2): + The angle of the negative slope used by the leaky ReLU activation. + + Example: + + ```python + >>> from transformers import UnivNetModel, UnivNetConfig + + >>> # Initializing a Tortoise TTS style configuration + >>> configuration = UnivNetConfig() + + >>> # Initializing a model (with random weights) from the Tortoise TTS style configuration + >>> model = UnivNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "univnet" + + def __init__( + self, + model_in_channels=64, + model_hidden_channels=32, + num_mel_bins=100, + resblock_kernel_sizes=[3, 3, 3], + resblock_stride_sizes=[8, 8, 4], + resblock_dilation_sizes=[[1, 3, 9, 27], [1, 3, 9, 27], [1, 3, 9, 27]], + kernel_predictor_num_blocks=3, + kernel_predictor_hidden_channels=64, + kernel_predictor_conv_size=3, + kernel_predictor_dropout=0.0, + initializer_range=0.01, + leaky_relu_slope=0.2, + **kwargs, + ): + if not (len(resblock_kernel_sizes) == len(resblock_stride_sizes) == len(resblock_dilation_sizes)): + raise ValueError( + "`resblock_kernel_sizes`, `resblock_stride_sizes`, and `resblock_dilation_sizes` must all have the" + " same length (which will be the number of resnet blocks in the model)." 
+ ) + + self.model_in_channels = model_in_channels + self.model_hidden_channels = model_hidden_channels + self.num_mel_bins = num_mel_bins + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_stride_sizes = resblock_stride_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.kernel_predictor_num_blocks = kernel_predictor_num_blocks + self.kernel_predictor_hidden_channels = kernel_predictor_hidden_channels + self.kernel_predictor_conv_size = kernel_predictor_conv_size + self.kernel_predictor_dropout = kernel_predictor_dropout + self.initializer_range = initializer_range + self.leaky_relu_slope = leaky_relu_slope + super().__init__(**kwargs) + + +__all__ = ["UnivNetConfig"] diff --git a/docs/transformers/src/transformers/models/univnet/convert_univnet.py b/docs/transformers/src/transformers/models/univnet/convert_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f790efab22f80db2545c8bc41a9df22f9efb22f1 --- /dev/null +++ b/docs/transformers/src/transformers/models/univnet/convert_univnet.py @@ -0,0 +1,162 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch + +from transformers import UnivNetConfig, UnivNetModel, logging + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.univnet") + + +def get_kernel_predictor_key_mapping(config: UnivNetConfig, old_prefix: str = "", new_prefix: str = ""): + mapping = {} + # Initial conv layer + mapping[f"{old_prefix}.input_conv.0.weight_g"] = f"{new_prefix}.input_conv.weight_g" + mapping[f"{old_prefix}.input_conv.0.weight_v"] = f"{new_prefix}.input_conv.weight_v" + mapping[f"{old_prefix}.input_conv.0.bias"] = f"{new_prefix}.input_conv.bias" + + # Kernel predictor resnet blocks + for i in range(config.kernel_predictor_num_blocks): + mapping[f"{old_prefix}.residual_convs.{i}.1.weight_g"] = f"{new_prefix}.resblocks.{i}.conv1.weight_g" + mapping[f"{old_prefix}.residual_convs.{i}.1.weight_v"] = f"{new_prefix}.resblocks.{i}.conv1.weight_v" + mapping[f"{old_prefix}.residual_convs.{i}.1.bias"] = f"{new_prefix}.resblocks.{i}.conv1.bias" + + mapping[f"{old_prefix}.residual_convs.{i}.3.weight_g"] = f"{new_prefix}.resblocks.{i}.conv2.weight_g" + mapping[f"{old_prefix}.residual_convs.{i}.3.weight_v"] = f"{new_prefix}.resblocks.{i}.conv2.weight_v" + mapping[f"{old_prefix}.residual_convs.{i}.3.bias"] = f"{new_prefix}.resblocks.{i}.conv2.bias" + + # Kernel output conv + mapping[f"{old_prefix}.kernel_conv.weight_g"] = f"{new_prefix}.kernel_conv.weight_g" + mapping[f"{old_prefix}.kernel_conv.weight_v"] = f"{new_prefix}.kernel_conv.weight_v" + mapping[f"{old_prefix}.kernel_conv.bias"] = f"{new_prefix}.kernel_conv.bias" + + # Bias output conv + mapping[f"{old_prefix}.bias_conv.weight_g"] = f"{new_prefix}.bias_conv.weight_g" + mapping[f"{old_prefix}.bias_conv.weight_v"] = f"{new_prefix}.bias_conv.weight_v" + mapping[f"{old_prefix}.bias_conv.bias"] = f"{new_prefix}.bias_conv.bias" + + return 
mapping + + +def get_key_mapping(config: UnivNetConfig): + mapping = {} + + # NOTE: inital conv layer keys are the same + + # LVC Residual blocks + for i in range(len(config.resblock_stride_sizes)): + # LVCBlock initial convt layer + mapping[f"res_stack.{i}.convt_pre.1.weight_g"] = f"resblocks.{i}.convt_pre.weight_g" + mapping[f"res_stack.{i}.convt_pre.1.weight_v"] = f"resblocks.{i}.convt_pre.weight_v" + mapping[f"res_stack.{i}.convt_pre.1.bias"] = f"resblocks.{i}.convt_pre.bias" + + # Kernel predictor + kernel_predictor_mapping = get_kernel_predictor_key_mapping( + config, old_prefix=f"res_stack.{i}.kernel_predictor", new_prefix=f"resblocks.{i}.kernel_predictor" + ) + mapping.update(kernel_predictor_mapping) + + # LVC Residual blocks + for j in range(len(config.resblock_dilation_sizes[i])): + mapping[f"res_stack.{i}.conv_blocks.{j}.1.weight_g"] = f"resblocks.{i}.resblocks.{j}.conv.weight_g" + mapping[f"res_stack.{i}.conv_blocks.{j}.1.weight_v"] = f"resblocks.{i}.resblocks.{j}.conv.weight_v" + mapping[f"res_stack.{i}.conv_blocks.{j}.1.bias"] = f"resblocks.{i}.resblocks.{j}.conv.bias" + + # Output conv layer + mapping["conv_post.1.weight_g"] = "conv_post.weight_g" + mapping["conv_post.1.weight_v"] = "conv_post.weight_v" + mapping["conv_post.1.bias"] = "conv_post.bias" + + return mapping + + +def rename_state_dict(state_dict, keys_to_modify, keys_to_remove): + model_state_dict = {} + for key, value in state_dict.items(): + if key in keys_to_remove: + continue + + if key in keys_to_modify: + new_key = keys_to_modify[key] + model_state_dict[new_key] = value + else: + model_state_dict[key] = value + return model_state_dict + + +def convert_univnet_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + config_path=None, + repo_id=None, + safe_serialization=False, +): + model_state_dict_base = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + # Get the generator's state dict + state_dict = model_state_dict_base["model_g"] + + if config_path is not None: + config = UnivNetConfig.from_pretrained(config_path) + else: + config = UnivNetConfig() + + keys_to_modify = get_key_mapping(config) + keys_to_remove = set() + hf_state_dict = rename_state_dict(state_dict, keys_to_modify, keys_to_remove) + + model = UnivNetModel(config) + # Apply weight norm since the original checkpoint has weight norm applied + model.apply_weight_norm() + model.load_state_dict(hf_state_dict) + # Remove weight norm in preparation for inference + model.remove_weight_norm() + + model.save_pretrained(pytorch_dump_folder_path, safe_serialization=safe_serialization) + + if repo_id: + print("Pushing to the hub...") + model.push_to_hub(repo_id) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + parser.add_argument( + "--safe_serialization", action="store_true", help="Whether to save the model using `safetensors`." 
+ ) + + args = parser.parse_args() + + convert_univnet_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + args.push_to_hub, + args.safe_serialization, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/src/transformers/models/univnet/feature_extraction_univnet.py b/docs/transformers/src/transformers/models/univnet/feature_extraction_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5f43532d9df8089332460aa817b7458a0f1e3403 --- /dev/null +++ b/docs/transformers/src/transformers/models/univnet/feature_extraction_univnet.py @@ -0,0 +1,459 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for UnivNetModel.""" + +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class UnivNetFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a UnivNet feature extractor. + + This class extracts log-mel-filter bank features from raw speech using the short time Fourier Transform (STFT). The + STFT implementation follows that of TacoTron 2 and Hifi-GAN. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value to pad with when applying the padding strategy defined by the `padding` argument to + [`UnivNetFeatureExtractor.__call__`]. Should correspond to audio silence. The `pad_end` argument to + `__call__` will also use this padding value. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve the + performance for some models. + num_mel_bins (`int`, *optional*, defaults to 100): + The number of mel-frequency bins in the extracted spectrogram features. This should match + `UnivNetModel.config.num_mel_bins`. + hop_length (`int`, *optional*, defaults to 256): + The direct number of samples between sliding windows. Otherwise referred to as "shift" in many papers. Note + that this is different from other audio feature extractors such as [`SpeechT5FeatureExtractor`] which take + the `hop_length` in ms. 
+        win_length (`int`, *optional*, defaults to 1024):
+            The direct number of samples for each sliding window. Note that this is different from other audio feature
+            extractors such as [`SpeechT5FeatureExtractor`] which take the `win_length` in ms.
+        win_function (`str`, *optional*, defaults to `"hann_window"`):
+            Name for the window function used for windowing, must be accessible via `torch.{win_function}`.
+        filter_length (`int`, *optional*, defaults to 1024):
+            The number of FFT components to use. If `None`, this is determined using
+            `transformers.audio_utils.optimal_fft_length`.
+        max_length_s (`int`, *optional*, defaults to 10):
+            The maximum input length of the model in seconds. This is used to pad the audio.
+        fmin (`float`, *optional*, defaults to 0.0):
+            Minimum mel frequency in Hz.
+        fmax (`float`, *optional*):
+            Maximum mel frequency in Hz. If not set, defaults to `sampling_rate / 2`.
+        mel_floor (`float`, *optional*, defaults to 1e-09):
+            Minimum value of mel frequency banks. Note that the way [`UnivNetFeatureExtractor`] uses `mel_floor` is
+            different from [`transformers.audio_utils.spectrogram`].
+        center (`bool`, *optional*, defaults to `False`):
+            Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
+            `t` will start at time `t * hop_length`.
+        compression_factor (`float`, *optional*, defaults to 1.0):
+            The multiplicative compression factor for dynamic range compression during spectral normalization.
+        compression_clip_val (`float`, *optional*, defaults to 1e-05):
+            The clip value applied to the waveform before applying dynamic range compression during spectral
+            normalization.
+        normalize_min (`float`, *optional*, defaults to -11.512925148010254):
+            The min value used for Tacotron 2-style linear normalization. The default is the original value from the
+            Tacotron 2 implementation.
+        normalize_max (`float`, *optional*, defaults to 2.3143386840820312):
+            The max value used for Tacotron 2-style linear normalization. The default is the original value from the
+            Tacotron 2 implementation.
+        model_in_channels (`int`, *optional*, defaults to 64):
+            The number of input channels to the [`UnivNetModel`] model. This should match
+            `UnivNetModel.config.model_in_channels`.
+        pad_end_length (`int`, *optional*, defaults to 10):
+            If padding the end of each waveform, the number of spectrogram frames worth of samples to append. The
+            number of appended samples will be `pad_end_length * hop_length`.
+        return_attention_mask (`bool`, *optional*, defaults to `True`):
+            Whether or not [`~UnivNetFeatureExtractor.__call__`] should return `attention_mask`.
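+
+    Example:
+
+    A minimal usage sketch. The `dg845/univnet-dev` checkpoint name is the one used in the UnivNet model docstrings
+    elsewhere in this patch and is assumed to be available; the waveform here is random placeholder audio.
+
+    ```python
+    >>> import numpy as np
+    >>> from transformers import UnivNetFeatureExtractor
+
+    >>> feature_extractor = UnivNetFeatureExtractor.from_pretrained("dg845/univnet-dev")
+    >>> # One second of mono audio at the expected 24 kHz sampling rate
+    >>> waveform = np.random.randn(24000).astype(np.float32)
+    >>> inputs = feature_extractor(waveform, sampling_rate=24000, return_tensors="np")
+    >>> # `inputs` holds "input_features" (log-mel frames), "padding_mask", and "noise_sequence"
+    ```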
+ """ + + model_input_names = ["input_features", "noise_sequence", "padding_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 24000, + padding_value: float = 0.0, + do_normalize: bool = False, + num_mel_bins: int = 100, + hop_length: int = 256, + win_length: int = 1024, + win_function: str = "hann_window", + filter_length: Optional[int] = 1024, + max_length_s: int = 10, + fmin: float = 0.0, + fmax: Optional[float] = None, + mel_floor: float = 1e-9, + center: bool = False, + compression_factor: float = 1.0, + compression_clip_val: float = 1e-5, + normalize_min: float = -11.512925148010254, + normalize_max: float = 2.3143386840820312, + model_in_channels: int = 64, + pad_end_length: int = 10, + return_attention_mask=True, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + self.do_normalize = do_normalize + + self.num_mel_bins = num_mel_bins + self.hop_length = hop_length + self.win_length = win_length + self.win_function = win_function + self.filter_length = filter_length + self.fmin = fmin + if fmax is None: + # Follows the librosa.filters.mel implementation + fmax = float(sampling_rate) / 2 + self.fmax = fmax + self.mel_floor = mel_floor + + self.max_length_s = max_length_s + self.num_max_samples = max_length_s * sampling_rate + + if self.filter_length is None: + self.n_fft = optimal_fft_length(self.win_length) + else: + self.n_fft = self.filter_length + self.n_freqs = (self.n_fft // 2) + 1 + + self.window = window_function(window_length=self.win_length, name=self.win_function, periodic=True) + + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_freqs, + num_mel_filters=self.num_mel_bins, + min_frequency=self.fmin, + max_frequency=self.fmax, + sampling_rate=self.sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + self.center = center + self.compression_factor = compression_factor + self.compression_clip_val = compression_clip_val + self.normalize_min = normalize_min + self.normalize_max = normalize_max + self.model_in_channels = model_in_channels + self.pad_end_length = pad_end_length + + def normalize(self, spectrogram): + return 2 * ((spectrogram - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 + + def denormalize(self, spectrogram): + return self.normalize_min + (self.normalize_max - self.normalize_min) * ((spectrogram + 1) / 2) + + def mel_spectrogram(self, waveform: np.ndarray) -> np.ndarray: + """ + Calculates log MEL spectrograms from a batch of waveforms. Note that the input waveform(s) will be padded by + `int(self.n_fft - self.hop_length) / 2` on both sides using the `reflect` padding mode. + + Args: + waveform (`np.ndarray` of shape `(length,)`): + The input waveform. This must be a single real-valued, mono waveform. + + Returns: + `numpy.ndarray`: Array containing a log-mel spectrogram of shape `(num_frames, num_mel_bins)`. + """ + # Do custom padding based on the official MelGAN and Hifi-GAN implementations + # See https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/utils/stft.py#L84-L86 + waveform = np.pad( + waveform, + (int((self.n_fft - self.hop_length) / 2), int((self.n_fft - self.hop_length) / 2)), + mode="reflect", + ) + + # Get the complex spectrogram. + # Note: waveform must be unbatched currently due to the implementation of spectrogram(...). 
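+        # `power=None` below requests the raw complex STFT so that the amplitude spectrogram and mel
+        # projection can be applied by hand afterwards: UnivNet adds `mel_floor` inside the square root
+        # rather than clamping the mel spectrogram the way `transformers.audio_utils.spectrogram` would,
+        # which is why `mel_filters` and `mel_floor` are not passed through to `spectrogram(...)`.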
+        complex_spectrogram = spectrogram(
+            waveform,
+            window=self.window,
+            frame_length=self.n_fft,
+            hop_length=self.hop_length,
+            fft_length=self.n_fft,
+            power=None,
+            center=self.center,
+            mel_filters=None,
+            mel_floor=None,
+        )
+
+        # Apply the MEL filter bank and MEL floor manually since UnivNet uses a slightly different implementation
+        amplitude_spectrogram = np.sqrt(
+            np.real(complex_spectrogram) ** 2 + np.imag(complex_spectrogram) ** 2 + self.mel_floor
+        )
+        mel_spectrogram = np.matmul(self.mel_filters.T, amplitude_spectrogram)
+
+        # Perform spectral normalization to get the log mel spectrogram.
+        log_mel_spectrogram = np.log(
+            np.clip(mel_spectrogram, a_min=self.compression_clip_val, a_max=None) * self.compression_factor
+        )
+
+        # Return spectrogram with num_mel_bins last
+        return log_mel_spectrogram.T
+
+    def generate_noise(
+        self,
+        noise_length: int,
+        generator: Optional[np.random.Generator] = None,
+    ) -> np.ndarray:
+        """
+        Generates a random noise sequence of standard Gaussian noise for use in the `noise_sequence` argument of
+        [`UnivNetModel.forward`].
+
+        Args:
+            noise_length (`int`):
+                The length (dim 0) of the generated noise.
+            generator (`numpy.random.Generator`, *optional*, defaults to `None`):
+                An optional `numpy.random.Generator` random number generator to control noise generation. If not set, a
+                new generator with fresh entropy will be created.
+
+        Returns:
+            `numpy.ndarray`: Array containing random standard Gaussian noise of shape `(noise_length,
+            self.model_in_channels)`.
+        """
+        if generator is None:
+            generator = np.random.default_rng()
+
+        noise_shape = (noise_length, self.model_in_channels)
+        noise = generator.standard_normal(noise_shape, dtype=np.float32)
+
+        return noise
+
+    def batch_decode(self, waveforms, waveform_lengths=None) -> List[np.ndarray]:
+        r"""
+        Removes padding from generated audio after running [`UnivNetModel.forward`]. This returns a ragged list of 1D
+        audio waveform arrays and not a single tensor/array because in general the waveforms will have different
+        lengths after removing padding.
+
+        Args:
+            waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+                The batched output waveforms from the [`UnivNetModel`].
+            waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
+                The batched lengths of each waveform before padding.
+
+        Returns:
+            `List[np.ndarray]`: A ragged list of 1D waveform arrays with padding removed.
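+
+        Example (a sketch; `model_output` stands in for a [`UnivNetModelOutput`] returned by
+        [`UnivNetModel.forward`] and is not defined here):
+
+        ```python
+        >>> audio = feature_extractor.batch_decode(
+        ...     waveforms=model_output.waveforms, waveform_lengths=model_output.waveform_lengths
+        ... )
+        >>> # `audio` is a list of 1D numpy arrays, one per batch element, with padding stripped
+        ```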
+        """
+        # Collapse the batched waveform tensor to a list of 1D audio waveforms
+        waveforms = [waveform.detach().to(device="cpu", copy=True).numpy() for waveform in waveforms]
+
+        if waveform_lengths is not None:
+            waveforms = [waveform[: waveform_lengths[i]] for i, waveform in enumerate(waveforms)]
+
+        return waveforms
+
+    def __call__(
+        self,
+        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+        sampling_rate: Optional[int] = None,
+        padding: Union[bool, str, PaddingStrategy] = True,
+        max_length: Optional[int] = None,
+        truncation: bool = True,
+        pad_to_multiple_of: Optional[int] = None,
+        return_noise: bool = True,
+        generator: Optional[np.random.Generator] = None,
+        pad_end: bool = False,
+        pad_length: Optional[int] = None,
+        do_normalize: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> BatchFeature:
+        """
+        Main method to featurize and prepare for the model one or several sequence(s).
+
+        Args:
+            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
+            sampling_rate (`int`, *optional*):
+                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                `sampling_rate` at the forward call to prevent silent errors and to allow automatic speech recognition
+                pipelines to work correctly.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the input `raw_speech` waveforms (according to the model's padding side and
+                padding index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+
+                If `pad_end = True`, that padding will occur before the `padding` strategy is applied.
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            truncation (`bool`, *optional*, defaults to `True`):
+                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+            return_noise (`bool`, *optional*, defaults to `True`):
+                Whether to generate and return a noise waveform for use in [`UnivNetModel.forward`].
+            generator (`numpy.random.Generator`, *optional*, defaults to `None`):
+                An optional `numpy.random.Generator` random number generator to use when generating noise.
+            pad_end (`bool`, *optional*, defaults to `False`):
+                Whether to pad the end of each waveform with silence. This can help reduce artifacts at the end of the
+                generated audio sample; see https://github.com/seungwonpark/melgan/issues/8 for more details. This
+                padding will be done before the padding strategy specified in `padding` is performed.
+            pad_length (`int`, *optional*, defaults to `None`):
+                If padding the end of each waveform, the length of the padding in spectrogram frames. If not set, this
+                will default to `self.config.pad_end_length`.
+            do_normalize (`bool`, *optional*):
+                Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve
+                the performance for some models. If not set, this will default to `self.config.do_normalize`.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific feature_extractor's default.
+
+                [What are attention masks?](../glossary#attention-mask)
+
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of lists of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+        """
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
+                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
+                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+        )
+
+        if is_batched:
+            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
+        elif not is_batched and not isinstance(raw_speech, np.ndarray):
+            raw_speech = np.asarray(raw_speech, dtype=np.float32)
+        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+            raw_speech = raw_speech.astype(np.float32)
+
+        # always return batch
+        if not is_batched:
+            raw_speech = [np.asarray(raw_speech, dtype=np.float32)]
+
+        # Pad end to reduce artifacts
+        if pad_end:
+            pad_length = pad_length if pad_length is not None else self.pad_end_length
+            raw_speech = [
+                np.pad(waveform, (0, pad_length * self.hop_length), constant_values=self.padding_value)
+                for waveform in raw_speech
+            ]
+
+        batched_speech = BatchFeature({"input_features": raw_speech})
+
+        padded_inputs = self.pad(
+            batched_speech,
+            padding=padding,
+            max_length=max_length if max_length is not None else self.num_max_samples,
+            truncation=truncation,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_attention_mask=return_attention_mask,
+        )
+
+        # make sure list is in array format
+        input_features = padded_inputs.get("input_features")
+
+        mel_spectrograms = [self.mel_spectrogram(waveform) for waveform in input_features]
+
+        if isinstance(input_features[0], List):
+            batched_speech["input_features"] = [np.asarray(mel,
dtype=np.float32) for mel in mel_spectrograms] + else: + batched_speech["input_features"] = [mel.astype(np.float32) for mel in mel_spectrograms] + + # convert attention_mask to correct format + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + batched_speech["padding_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + if return_noise: + noise = [ + self.generate_noise(spectrogram.shape[0], generator) + for spectrogram in batched_speech["input_features"] + ] + batched_speech["noise_sequence"] = noise + + if do_normalize: + batched_speech["input_features"] = [ + self.normalize(spectrogram) for spectrogram in batched_speech["input_features"] + ] + + if return_tensors is not None: + batched_speech = batched_speech.convert_to_tensors(return_tensors) + + return batched_speech + + def to_dict(self) -> Dict[str, Any]: + output = super().to_dict() + + # Don't serialize these as they are derived from the other properties. + names = ["window", "mel_filters", "n_fft", "n_freqs", "num_max_samples"] + for name in names: + if name in output: + del output[name] + + return output + + +__all__ = ["UnivNetFeatureExtractor"] diff --git a/docs/transformers/src/transformers/models/univnet/modeling_univnet.py b/docs/transformers/src/transformers/models/univnet/modeling_univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d105711b5dc3eb8254b0a438af48e0ec4a070ff9 --- /dev/null +++ b/docs/transformers/src/transformers/models/univnet/modeling_univnet.py @@ -0,0 +1,655 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UnivNetModel model.""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_univnet import UnivNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "UnivNetConfig" + +_CHECKPOINT_FOR_DOC = "dg845/univnet-dev" + + +@dataclass +class UnivNetModelOutput(ModelOutput): + """ + Output class for the [`UnivNetModel`], which includes the generated audio waveforms and the original unpadded + lengths of those waveforms (so that the padding can be removed by [`UnivNetModel.batch_decode`]). + + Args: + waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Batched 1D (mono-channel) output audio waveforms. + waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`): + The batched length in samples of each unpadded waveform in `waveforms`. 
+ """ + + waveforms: Optional[torch.FloatTensor] = None + waveform_lengths: Optional[torch.FloatTensor] = None + + +class UnivNetKernelPredictorResidualBlock(nn.Module): + """ + Implementation of the residual block for the kernel predictor network inside each location variable convolution + block (LVCBlock). + + Parameters: + config: (`UnivNetConfig`): + Config for the `UnivNetModel` model. + """ + + def __init__( + self, + config: UnivNetConfig, + ): + super().__init__() + self.channels = config.model_in_channels + self.kernel_size = config.kernel_predictor_conv_size + self.dropout_prob = config.kernel_predictor_dropout + self.leaky_relu_slope = config.leaky_relu_slope + + padding = (self.kernel_size - 1) // 2 + + self.dropout = nn.Dropout(self.dropout_prob) + self.conv1 = nn.Conv1d(self.channels, self.channels, self.kernel_size, padding=padding, bias=True) + self.conv2 = nn.Conv1d(self.channels, self.channels, self.kernel_size, padding=padding, bias=True) + + def forward(self, hidden_states: torch.FloatTensor): + # hidden_states should have shape (batch_size, channels, seq_length) + residual = hidden_states + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.conv2(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + return hidden_states + residual + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv1) + weight_norm(self.conv2) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv1) + nn.utils.remove_weight_norm(self.conv2) + + +class UnivNetKernelPredictor(nn.Module): + """ + Implementation of the kernel predictor network which supplies the kernel and bias for the location variable + convolutional layers (LVCs) in each UnivNet LVCBlock. + + Based on the KernelPredictor implementation in + [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L7). + + Parameters: + config: (`UnivNetConfig`): + Config for the `UnivNetModel` model. + conv_kernel_size (`int`, *optional*, defaults to 3): + The kernel size for the location variable convolutional layer kernels (convolutional weight tensor). + conv_layers (`int`, *optional*, defaults to 4): + The number of location variable convolutional layers to output kernels and biases for. 
+    """
+
+    def __init__(
+        self,
+        config: UnivNetConfig,
+        conv_kernel_size: int = 3,
+        conv_layers: int = 4,
+    ):
+        super().__init__()
+
+        self.conv_in_channels = config.model_hidden_channels
+        self.conv_out_channels = 2 * config.model_hidden_channels
+        self.conv_kernel_size = conv_kernel_size
+        self.conv_layers = conv_layers
+
+        self.kernel_channels = (
+            self.conv_in_channels * self.conv_out_channels * self.conv_kernel_size * self.conv_layers
+        )
+        self.bias_channels = self.conv_out_channels * self.conv_layers
+
+        self.resnet_in_channels = config.num_mel_bins
+        self.resnet_hidden_channels = config.kernel_predictor_hidden_channels
+        self.resnet_kernel_size = config.kernel_predictor_conv_size
+        self.num_blocks = config.kernel_predictor_num_blocks
+
+        self.leaky_relu_slope = config.leaky_relu_slope
+
+        padding = (self.resnet_kernel_size - 1) // 2
+
+        self.input_conv = nn.Conv1d(self.resnet_in_channels, self.resnet_hidden_channels, 5, padding=2, bias=True)
+
+        self.resblocks = nn.ModuleList([UnivNetKernelPredictorResidualBlock(config) for _ in range(self.num_blocks)])
+
+        self.kernel_conv = nn.Conv1d(
+            self.resnet_hidden_channels, self.kernel_channels, self.resnet_kernel_size, padding=padding, bias=True
+        )
+        self.bias_conv = nn.Conv1d(
+            self.resnet_hidden_channels, self.bias_channels, self.resnet_kernel_size, padding=padding, bias=True
+        )
+
+    def forward(self, spectrogram: torch.FloatTensor):
+        """
+        Maps a conditioning log-mel spectrogram to a tensor of convolutional kernels and biases, for use in location
+        variable convolutional layers. Note that the input spectrogram should have shape (batch_size, input_channels,
+        seq_length).
+
+        Args:
+            spectrogram (`torch.FloatTensor` of shape `(batch_size, input_channels, seq_length)`):
+                Tensor containing the log-mel spectrograms.
+
+        Returns:
+            Tuple[`torch.FloatTensor`, `torch.FloatTensor`]: tuple of tensors where the first element is the tensor of
+            location variable convolution kernels of shape `(batch_size, self.conv_layers, self.conv_in_channels,
+            self.conv_out_channels, self.conv_kernel_size, seq_length)` and the second element is the tensor of
+            location variable convolution biases of shape `(batch_size, self.conv_layers, self.conv_out_channels,
+            seq_length)`.
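+
+        Example shape check (a sketch; the concrete numbers assume a default `UnivNetConfig`, i.e. a hidden size of
+        32 and 100 mel bins):
+
+        ```python
+        >>> import torch
+        >>> from transformers import UnivNetConfig
+        >>> from transformers.models.univnet.modeling_univnet import UnivNetKernelPredictor
+
+        >>> config = UnivNetConfig()
+        >>> predictor = UnivNetKernelPredictor(config, conv_kernel_size=3, conv_layers=4)
+        >>> kernels, biases = predictor(torch.randn(1, config.num_mel_bins, 10))
+        >>> # kernels: (1, 4, 32, 64, 3, 10), biases: (1, 4, 64, 10)
+        ```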
+ """ + batch_size, _, seq_length = spectrogram.shape + + hidden_states = self.input_conv(spectrogram) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + + for resblock in self.resblocks: + hidden_states = resblock(hidden_states) + + kernel_hidden_states = self.kernel_conv(hidden_states) + bias_hidden_states = self.bias_conv(hidden_states) + + # Reshape kernels and biases to appropriate shape + kernels = kernel_hidden_states.view( + batch_size, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + seq_length, + ).contiguous() + biases = bias_hidden_states.view( + batch_size, + self.conv_layers, + self.conv_out_channels, + seq_length, + ).contiguous() + + return kernels, biases + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.input_conv) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.kernel_conv) + weight_norm(self.bias_conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.input_conv) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.kernel_conv) + nn.utils.remove_weight_norm(self.bias_conv) + + +class UnivNetLvcResidualBlock(nn.Module): + """ + Implementation of the location variable convolution (LVC) residual block for the UnivNet residual network. + + Parameters: + config: (`UnivNetConfig`): + Config for the `UnivNetModel` model. + kernel_size (`int`): + The kernel size for the dilated 1D convolutional layer. + dilation (`int`): + The dilation for the dilated 1D convolutional layer. + """ + + def __init__( + self, + config: UnivNetConfig, + kernel_size: int, + dilation: int, + ): + super().__init__() + self.hidden_channels = config.model_hidden_channels + self.kernel_size = kernel_size + self.dilation = dilation + self.leaky_relu_slope = config.leaky_relu_slope + + padding = self.dilation * (self.kernel_size - 1) // 2 + + self.conv = nn.Conv1d( + self.hidden_channels, + self.hidden_channels, + self.kernel_size, + padding=padding, + dilation=self.dilation, + ) + + def forward(self, hidden_states, kernel, bias, hop_size=256): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.location_variable_convolution(hidden_states, kernel, bias, hop_size=hop_size) + # Gated activation unit + hidden_states = torch.sigmoid(hidden_states[:, : self.hidden_channels, :]) * torch.tanh( + hidden_states[:, self.hidden_channels :, :] + ) + # Skip connection + hidden_states = residual + hidden_states + + return hidden_states + + # Based on https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L171 + def location_variable_convolution( + self, + hidden_states: torch.FloatTensor, + kernel: torch.FloatTensor, + bias: torch.FloatTensor, + dilation: int = 1, + hop_size: int = 256, + ): + """ + Performs location-variable convolution operation on the input sequence (hidden_states) using the local + convolution kernel. This was introduced in [LVCNet: Efficient Condition-Dependent Modeling Network for Waveform + Generation](https://arxiv.org/abs/2102.10815) by Zhen Zheng, Jianzong Wang, Ning Cheng, and Jing Xiao. + + Time: 414 μs ± 309 ns per loop (mean ± std. 
dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, in_channels, in_length)`): + The input sequence of shape (batch, in_channels, in_length). + kernel (`torch.FloatTensor` of shape `(batch_size, in_channels, out_channels, kernel_size, kernel_length)`): + The local convolution kernel of shape (batch, in_channels, out_channels, kernel_size, kernel_length). + bias (`torch.FloatTensor` of shape `(batch_size, out_channels, kernel_length)`): + The bias for the local convolution of shape (batch, out_channels, kernel_length). + dilation (`int`, *optional*, defaults to 1): + The dilation of convolution. + hop_size (`int`, *optional*, defaults to 256): + The hop_size of the conditioning sequence. + Returns: + `torch.FloatTensor`: the output sequence after performing local convolution with shape (batch_size, + out_channels, in_length). + """ + batch, _, in_length = hidden_states.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + if in_length != (kernel_length * hop_size): + raise ValueError( + f"Dim 2 of `hidden_states` should be {kernel_length * hop_size}) but got {in_length}. Please check" + " `hidden_states` or `kernel` and `hop_size` to make sure they are correct." + ) + + padding = dilation * int((kernel_size - 1) / 2) + + # (batch, in_channels, in_length + 2*padding) + hidden_states = nn.functional.pad(hidden_states, (padding, padding), "constant", 0) + # (batch, in_channels, kernel_length, hop_size + 2*padding) + hidden_states = hidden_states.unfold(2, hop_size + 2 * padding, hop_size) + + if hop_size < dilation: + hidden_states = nn.functional.pad(hidden_states, (0, dilation), "constant", 0) + # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + hidden_states = hidden_states.unfold(3, dilation, dilation) + hidden_states = hidden_states[:, :, :, :, :hop_size] + # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + hidden_states = hidden_states.transpose(3, 4) + # (batch, in_channels, kernel_length, dilation, _, kernel_size) + hidden_states = hidden_states.unfold(4, kernel_size, 1) + + # Apply local convolution kernel to hidden_states. + output_hidden_states = torch.einsum("bildsk,biokl->bolsd", hidden_states, kernel) + + output_hidden_states = output_hidden_states.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + output_hidden_states = output_hidden_states + bias + output_hidden_states = output_hidden_states.contiguous().view(batch, out_channels, -1) + + return output_hidden_states + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + +class UnivNetLvcBlock(nn.Module): + """ + Implementation of the location variable convolution (LVC) residual block of the UnivNet residual block. Includes a + `UnivNetKernelPredictor` inside to predict the kernels and biases of the LVC layers. + + Based on LVCBlock in + [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L98) + + Parameters: + config (`UnivNetConfig`): + Config for the `UnivNetModel` model. + layer_id (`int`): + An integer corresponding to the index of the current LVC resnet block layer. 
This should be between 0 and + `len(config.resblock_stride_sizes) - 1)` inclusive. + lvc_hop_size (`int`, *optional*, defaults to 256): + The hop size for the location variable convolutional layers. + """ + + def __init__( + self, + config: UnivNetConfig, + layer_id: int, + lvc_hop_size: int = 256, + ): + super().__init__() + self.hidden_channels = config.model_hidden_channels + self.kernel_size = config.resblock_kernel_sizes[layer_id] + self.stride = config.resblock_stride_sizes[layer_id] + self.dilations = config.resblock_dilation_sizes[layer_id] + self.cond_hop_length = lvc_hop_size + self.leaky_relu_slope = config.leaky_relu_slope + self.num_blocks = len(self.dilations) + + self.convt_pre = nn.ConvTranspose1d( + self.hidden_channels, + self.hidden_channels, + 2 * self.stride, + stride=self.stride, + padding=self.stride // 2 + self.stride % 2, + output_padding=self.stride % 2, + ) + + self.kernel_predictor = UnivNetKernelPredictor(config, self.kernel_size, self.num_blocks) + + self.resblocks = nn.ModuleList( + [UnivNetLvcResidualBlock(config, self.kernel_size, self.dilations[i]) for i in range(self.num_blocks)] + ) + + def forward(self, hidden_states: torch.FloatTensor, spectrogram: torch.FloatTensor): + # hidden_states: (batch_size, hidden_channels, seq_length) + # spectrogram: (batch_size, cond_channels, cond_length) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.convt_pre(hidden_states) + + kernels, biases = self.kernel_predictor(spectrogram) + + for i, resblock in enumerate(self.resblocks): + kernel = kernels[:, i, :, :, :, :] + bias = biases[:, i, :, :] + hidden_states = resblock(hidden_states, kernel, bias, hop_size=self.cond_hop_length) + + return hidden_states + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.convt_pre) + self.kernel_predictor.apply_weight_norm() + for layer in self.resblocks: + layer.apply_weight_norm() + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.convt_pre) + self.kernel_predictor.remove_weight_norm() + for layer in self.resblocks: + layer.remove_weight_norm() + + +UNIVNET_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`UnivNetConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +UNIVNET_INPUTS_DOCSTRING = r""" + Converts a noise waveform and a conditioning spectrogram to a speech waveform. Passing a batch of log-mel + spectrograms returns a batch of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a + single, un-batched speech waveform. + + Args: + input_features (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. 
Can be batched and of shape `(batch_size, sequence_length, + config.num_mel_channels)`, or un-batched and of shape `(sequence_length, config.num_mel_channels)`. + noise_sequence (`torch.FloatTensor`, *optional*): + Tensor containing a noise sequence of standard Gaussian noise. Can be batched and of shape `(batch_size, + sequence_length, config.model_in_channels)`, or un-batched and of shape (sequence_length, + config.model_in_channels)`. If not supplied, will be randomly generated. + padding_mask (`torch.BoolTensor`, *optional*): + Mask indicating which parts of each sequence are padded. Mask values are selected in `[0, 1]`: + + - 1 for tokens that are **not masked** + - 0 for tokens that are **masked** + + The mask can be batched and of shape `(batch_size, sequence_length)` or un-batched and of shape + `(sequence_length,)`. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + return_dict: + Whether to return a [`~utils.ModelOutput`] subclass instead of a plain tuple. +""" + + +@add_start_docstrings( + """UnivNet GAN vocoder.""", + UNIVNET_START_DOCSTRING, +) +class UnivNetModel(PreTrainedModel): + config_class = UnivNetConfig + main_input_name = "input_features" + + def __init__(self, config: UnivNetConfig): + super().__init__(config) + + self.num_kernels = len(config.resblock_kernel_sizes) + self.leaky_relu_slope = config.leaky_relu_slope + + self.conv_pre = nn.Conv1d( + config.model_in_channels, + config.model_hidden_channels, + kernel_size=7, + stride=1, + padding=3, + padding_mode="reflect", + ) + + # Initialize location-variable convolution ResNet Blocks. + num_layers = len(config.resblock_stride_sizes) + hop_length = 1 + hop_lengths = [] + for stride in config.resblock_stride_sizes: + hop_length = hop_length * stride + hop_lengths.append(hop_length) + + self.resblocks = nn.ModuleList( + [ + UnivNetLvcBlock( + config, + layer_id=i, + lvc_hop_size=hop_lengths[i], + ) + for i in range(num_layers) + ] + ) + + self.conv_post = nn.Conv1d(config.model_hidden_channels, 1, 7, padding=3, padding_mode="reflect") + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(UNIVNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=UnivNetModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features: torch.FloatTensor, + noise_sequence: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], UnivNetModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import UnivNetFeatureExtractor, UnivNetModel + >>> from datasets import load_dataset, Audio + + >>> model = UnivNetModel.from_pretrained("dg845/univnet-dev") + >>> feature_extractor = UnivNetFeatureExtractor.from_pretrained("dg845/univnet-dev") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> # Resample the audio to the feature extractor's sampling rate. + >>> ds = ds.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) + >>> inputs = feature_extractor( + ... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt" + ... 
) + >>> audio = model(**inputs).waveforms + >>> list(audio.shape) + [1, 140288] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Resolve batch sizes for noise_sequence and spectrogram + spectrogram_batched = input_features.dim() == 3 + if not spectrogram_batched: + input_features = input_features.unsqueeze(0) + spectrogram_batch_size, spectrogram_length, _ = input_features.shape + + if noise_sequence is not None: + noise_sequence_batched = noise_sequence.dim() == 3 + if not noise_sequence_batched: + noise_sequence = noise_sequence.unsqueeze(0) + else: + # Randomly generate noise_sequence + noise_sequence_shape = (spectrogram_batch_size, spectrogram_length, self.config.model_in_channels) + noise_sequence = torch.randn( + noise_sequence_shape, generator=generator, dtype=input_features.dtype, device=input_features.device + ) + noise_sequence_batch_size = noise_sequence.shape[0] + + if spectrogram_batch_size > 1 and noise_sequence_batch_size == 1: + # Repeat noise_sequence spectrogram_batch_size times + noise_sequence = noise_sequence.repeat(spectrogram_batch_size, 1, 1) + elif noise_sequence_batch_size > 1 and spectrogram_batch_size == 1: + # Repeat spectrogram noise_sequence_batch_size times + input_features = input_features.repeat(noise_sequence_batch_size, 1, 1) + + if noise_sequence_batch_size != spectrogram_batch_size: + raise ValueError( + f"The batch size of `noise_sequence` is {noise_sequence_batch_size} and the batch size of" + f" `input_features` is {spectrogram_batch_size}, but the two are expected to be equal." + ) + + if padding_mask is not None: + if padding_mask.dim() == 1: + padding_mask = padding_mask.unsqueeze(0) + padding_mask_batch_size = padding_mask.shape[0] + if padding_mask_batch_size != spectrogram_batch_size: + raise ValueError( + f"The batch size of `padding_mask` is {padding_mask_batch_size} and the batch size of" + f" `input_features` is {spectrogram_batch_size}, but the two are expected to be equal." + ) + + # Change shapes to have channels before sequence lengths + hidden_states = noise_sequence.transpose(2, 1) + input_features = input_features.transpose(2, 1) + + hidden_states = self.conv_pre(hidden_states) + + for resblock in self.resblocks: + hidden_states = resblock(hidden_states, input_features) + + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + # Remove sequence length dimension since this collapses to 1 + # NOTE: keep waveforms batched even if there's only one + waveform = hidden_states.squeeze(1) + + # Get sequence lengths for UnivNetFeatureExtractor.batch_decode. 
+ waveform_lengths = None + if padding_mask is not None: + # Padding is always contiguous and added on the right + waveform_lengths = torch.sum(padding_mask, dim=1) + + if not return_dict: + outputs = (waveform, waveform_lengths) + return outputs + + return UnivNetModelOutput( + waveforms=waveform, + waveform_lengths=waveform_lengths, + ) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.conv_post) + + +__all__ = ["UnivNetModel"] diff --git a/docs/transformers/src/transformers/models/upernet/__init__.py b/docs/transformers/src/transformers/models/upernet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dee32c4758f0e4c54406473ec7146caa5514c413 --- /dev/null +++ b/docs/transformers/src/transformers/models/upernet/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_upernet import * + from .modeling_upernet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/src/transformers/models/upernet/configuration_upernet.py b/docs/transformers/src/transformers/models/upernet/configuration_upernet.py new file mode 100644 index 0000000000000000000000000000000000000000..c235c17d9cb62ac16d4e27bffff65fc35da3921f --- /dev/null +++ b/docs/transformers/src/transformers/models/upernet/configuration_upernet.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""UperNet model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto.configuration_auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class UperNetConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`UperNetForSemanticSegmentation`]. It is used to
+    instantiate an UperNet model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the UperNet
+    [openmmlab/upernet-convnext-tiny](https://huggingface.co/openmmlab/upernet-convnext-tiny) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        hidden_size (`int`, *optional*, defaults to 512):
+            The number of hidden units in the convolutional layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
+            Pooling scales used in Pooling Pyramid Module applied on the last feature map.
+        use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+            Whether to use an auxiliary head during training.
+        auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+            Weight of the cross-entropy loss of the auxiliary head.
+        auxiliary_in_channels (`int`, *optional*, defaults to 384):
+            Number of input channels for the auxiliary head.
+        auxiliary_channels (`int`, *optional*, defaults to 256):
+            Number of channels to use in the auxiliary head.
+        auxiliary_num_convs (`int`, *optional*, defaults to 1):
+            Number of convolutional layers to use in the auxiliary head.
+        auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
+            Whether to concatenate the output of the auxiliary head with the input before the classification layer.
+        loss_ignore_index (`int`, *optional*, defaults to 255):
+            The index that is ignored by the loss function.
+ + Examples: + + ```python + >>> from transformers import UperNetConfig, UperNetForSemanticSegmentation + + >>> # Initializing a configuration + >>> configuration = UperNetConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = UperNetForSemanticSegmentation(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "upernet" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + hidden_size=512, + initializer_range=0.02, + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_in_channels=384, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + loss_ignore_index=255, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.") + backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.pool_scales = pool_scales + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_in_channels = auxiliary_in_channels + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.loss_ignore_index = loss_ignore_index + + +__all__ = ["UperNetConfig"] diff --git a/docs/transformers/src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py b/docs/transformers/src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb3ab5fc9938171099a47feef23c4694d8b5169 --- /dev/null +++ b/docs/transformers/src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
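+
+# Example invocation (the dump folder path is a placeholder):
+#   python convert_convnext_upernet_to_pytorch.py \
+#       --model_name upernet-convnext-tiny \
+#       --pytorch_dump_folder_path ./upernet-convnext-tiny-converted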
+"""Convert ConvNext + UperNet checkpoints from mmsegmentation.""" + +import argparse +import json + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ConvNextConfig, SegformerImageProcessor, UperNetConfig, UperNetForSemanticSegmentation + + +def get_upernet_config(model_name): + auxiliary_in_channels = 384 + if "tiny" in model_name: + depths = [3, 3, 9, 3] + hidden_sizes = [96, 192, 384, 768] + if "small" in model_name: + depths = [3, 3, 27, 3] + hidden_sizes = [96, 192, 384, 768] + if "base" in model_name: + depths = [3, 3, 27, 3] + hidden_sizes = [128, 256, 512, 1024] + auxiliary_in_channels = 512 + if "large" in model_name: + depths = [3, 3, 27, 3] + hidden_sizes = [192, 384, 768, 1536] + auxiliary_in_channels = 768 + if "xlarge" in model_name: + depths = [3, 3, 27, 3] + hidden_sizes = [256, 512, 1024, 2048] + auxiliary_in_channels = 1024 + + # set label information + num_labels = 150 + repo_id = "huggingface/label-files" + filename = "ade20k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + label2id = {v: k for k, v in id2label.items()} + + backbone_config = ConvNextConfig( + depths=depths, hidden_sizes=hidden_sizes, out_features=["stage1", "stage2", "stage3", "stage4"] + ) + config = UperNetConfig( + backbone_config=backbone_config, + auxiliary_in_channels=auxiliary_in_channels, + num_labels=num_labels, + id2label=id2label, + label2id=label2id, + ) + + return config + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("backbone.downsample_layers.0.0.weight", "backbone.embeddings.patch_embeddings.weight")) + rename_keys.append(("backbone.downsample_layers.0.0.bias", "backbone.embeddings.patch_embeddings.bias")) + rename_keys.append(("backbone.downsample_layers.0.1.weight", "backbone.embeddings.layernorm.weight")) + rename_keys.append(("backbone.downsample_layers.0.1.bias", "backbone.embeddings.layernorm.bias")) + # stages + for i in range(len(config.backbone_config.depths)): + for j in range(config.backbone_config.depths[i]): + rename_keys.append((f"backbone.stages.{i}.{j}.gamma", f"backbone.encoder.stages.{i}.layers.{j}.layer_scale_parameter")) + rename_keys.append((f"backbone.stages.{i}.{j}.depthwise_conv.weight", f"backbone.encoder.stages.{i}.layers.{j}.dwconv.weight")) + rename_keys.append((f"backbone.stages.{i}.{j}.depthwise_conv.bias", f"backbone.encoder.stages.{i}.layers.{j}.dwconv.bias")) + rename_keys.append((f"backbone.stages.{i}.{j}.norm.weight", f"backbone.encoder.stages.{i}.layers.{j}.layernorm.weight")) + rename_keys.append((f"backbone.stages.{i}.{j}.norm.bias", f"backbone.encoder.stages.{i}.layers.{j}.layernorm.bias")) + rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv1.weight", f"backbone.encoder.stages.{i}.layers.{j}.pwconv1.weight")) + rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv1.bias", f"backbone.encoder.stages.{i}.layers.{j}.pwconv1.bias")) + rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv2.weight", f"backbone.encoder.stages.{i}.layers.{j}.pwconv2.weight")) + rename_keys.append((f"backbone.stages.{i}.{j}.pointwise_conv2.bias", f"backbone.encoder.stages.{i}.layers.{j}.pwconv2.bias")) + if i > 0: + rename_keys.append((f"backbone.downsample_layers.{i}.0.weight", 
f"backbone.encoder.stages.{i}.downsampling_layer.0.weight")) + rename_keys.append((f"backbone.downsample_layers.{i}.0.bias", f"backbone.encoder.stages.{i}.downsampling_layer.0.bias")) + rename_keys.append((f"backbone.downsample_layers.{i}.1.weight", f"backbone.encoder.stages.{i}.downsampling_layer.1.weight")) + rename_keys.append((f"backbone.downsample_layers.{i}.1.bias", f"backbone.encoder.stages.{i}.downsampling_layer.1.bias")) + + rename_keys.append((f"backbone.norm{i}.weight", f"backbone.hidden_states_norms.stage{i+1}.weight")) + rename_keys.append((f"backbone.norm{i}.bias", f"backbone.hidden_states_norms.stage{i+1}.bias")) + + # decode head + rename_keys.extend( + [ + ("decode_head.conv_seg.weight", "decode_head.classifier.weight"), + ("decode_head.conv_seg.bias", "decode_head.classifier.bias"), + ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"), + ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"), + ] + ) + # fmt: on + + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def convert_upernet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): + model_name_to_url = { + "upernet-convnext-tiny": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_tiny_fp16_512x512_160k_ade20k/upernet_convnext_tiny_fp16_512x512_160k_ade20k_20220227_124553-cad485de.pth", + "upernet-convnext-small": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_small_fp16_512x512_160k_ade20k/upernet_convnext_small_fp16_512x512_160k_ade20k_20220227_131208-1b1e394f.pth", + "upernet-convnext-base": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_base_fp16_512x512_160k_ade20k/upernet_convnext_base_fp16_512x512_160k_ade20k_20220227_181227-02a24fc6.pth", + "upernet-convnext-large": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_large_fp16_640x640_160k_ade20k/upernet_convnext_large_fp16_640x640_160k_ade20k_20220226_040532-e57aa54d.pth", + "upernet-convnext-xlarge": "https://download.openmmlab.com/mmsegmentation/v0.5/convnext/upernet_convnext_xlarge_fp16_640x640_160k_ade20k/upernet_convnext_xlarge_fp16_640x640_160k_ade20k_20220226_080344-95fc38c2.pth", + } + checkpoint_url = model_name_to_url[model_name] + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["state_dict"] + + config = get_upernet_config(model_name) + model = UperNetForSemanticSegmentation(config) + model.eval() + + # replace "bn" => "batch_norm" + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + if "bn" in key: + key = key.replace("bn", "batch_norm") + state_dict[key] = val + + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + + model.load_state_dict(state_dict) + + # verify on image + url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + + processor = SegformerImageProcessor() + pixel_values = processor(image, return_tensors="pt").pixel_values + + with torch.no_grad(): + outputs = model(pixel_values) + + if model_name == "upernet-convnext-tiny": + expected_slice = torch.tensor( + [[-8.8110, -8.8110, -8.6521], [-8.8110, -8.8110, -8.6521], [-8.7746, -8.7746, -8.6130]] + ) + elif model_name == "upernet-convnext-small": + expected_slice = torch.tensor( + [[-8.8236, -8.8236, -8.6771], [-8.8236, -8.8236, 
-8.6771], [-8.7638, -8.7638, -8.6240]] + ) + elif model_name == "upernet-convnext-base": + expected_slice = torch.tensor( + [[-8.8558, -8.8558, -8.6905], [-8.8558, -8.8558, -8.6905], [-8.7669, -8.7669, -8.6021]] + ) + elif model_name == "upernet-convnext-large": + expected_slice = torch.tensor( + [[-8.6660, -8.6660, -8.6210], [-8.6660, -8.6660, -8.6210], [-8.6310, -8.6310, -8.5964]] + ) + elif model_name == "upernet-convnext-xlarge": + expected_slice = torch.tensor( + [[-8.4980, -8.4980, -8.3977], [-8.4980, -8.4980, -8.3977], [-8.4379, -8.4379, -8.3412]] + ) + print("Logits:", outputs.logits[0, 0, :3, :3]) + assert torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to hub") + model.push_to_hub(f"openmmlab/{model_name}") + processor.push_to_hub(f"openmmlab/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="upernet-convnext-tiny", + type=str, + choices=[f"upernet-convnext-{size}" for size in ["tiny", "small", "base", "large", "xlarge"]], + help="Name of the ConvNext UperNet model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_upernet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py b/docs/transformers/src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..9580af7c46a50c26c25fe5a9f2728188fbd0193e --- /dev/null +++ b/docs/transformers/src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py @@ -0,0 +1,297 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Swin Transformer + UperNet checkpoints from mmsegmentation. 
+ +URL: https://github.com/open-mmlab/mmsegmentation/tree/master/configs/swin +""" + +import argparse +import json + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import SegformerImageProcessor, SwinConfig, UperNetConfig, UperNetForSemanticSegmentation + + +def get_upernet_config(model_name): + auxiliary_in_channels = 384 + window_size = 7 + if "tiny" in model_name: + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + elif "small" in model_name: + embed_dim = 96 + depths = (2, 2, 18, 2) + num_heads = (3, 6, 12, 24) + elif "base" in model_name: + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + window_size = 12 + auxiliary_in_channels = 512 + elif "large" in model_name: + embed_dim = 192 + depths = (2, 2, 18, 2) + num_heads = (6, 12, 24, 48) + window_size = 12 + auxiliary_in_channels = 768 + + # set label information + num_labels = 150 + repo_id = "huggingface/label-files" + filename = "ade20k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + label2id = {v: k for k, v in id2label.items()} + + backbone_config = SwinConfig( + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + config = UperNetConfig( + backbone_config=backbone_config, + auxiliary_in_channels=auxiliary_in_channels, + num_labels=num_labels, + id2label=id2label, + label2id=label2id, + ) + + return config + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("backbone.patch_embed.projection.weight", "backbone.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("backbone.patch_embed.projection.bias", "backbone.embeddings.patch_embeddings.projection.bias")) + rename_keys.append(("backbone.patch_embed.norm.weight", "backbone.embeddings.norm.weight")) + rename_keys.append(("backbone.patch_embed.norm.bias", "backbone.embeddings.norm.bias")) + # stages + for i in range(len(config.backbone_config.depths)): + for j in range(config.backbone_config.depths[i]): + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.weight")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.bias")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.relative_position_bias_table", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_bias_table")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.relative_position_index", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_index")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.proj.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.attn.w_msa.proj.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.weight")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.norm2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.bias")) + 
rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.0.0.weight", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.0.0.bias", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.weight")) + rename_keys.append((f"backbone.stages.{i}.blocks.{j}.ffn.layers.1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.bias")) + + if i < 3: + rename_keys.append((f"backbone.stages.{i}.downsample.reduction.weight", f"backbone.encoder.layers.{i}.downsample.reduction.weight")) + rename_keys.append((f"backbone.stages.{i}.downsample.norm.weight", f"backbone.encoder.layers.{i}.downsample.norm.weight")) + rename_keys.append((f"backbone.stages.{i}.downsample.norm.bias", f"backbone.encoder.layers.{i}.downsample.norm.bias")) + rename_keys.append((f"backbone.norm{i}.weight", f"backbone.hidden_states_norms.stage{i+1}.weight")) + rename_keys.append((f"backbone.norm{i}.bias", f"backbone.hidden_states_norms.stage{i+1}.bias")) + + # decode head + rename_keys.extend( + [ + ("decode_head.conv_seg.weight", "decode_head.classifier.weight"), + ("decode_head.conv_seg.bias", "decode_head.classifier.bias"), + ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"), + ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"), + ] + ) + # fmt: on + + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, backbone_config): + num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))] + for i in range(len(backbone_config.depths)): + dim = num_features[i] + for j in range(backbone_config.depths[i]): + # fmt: off + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"backbone.stages.{i}.blocks.{j}.attn.w_msa.qkv.weight") + in_proj_bias = state_dict.pop(f"backbone.stages.{i}.blocks.{j}.attn.w_msa.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[ + dim : dim * 2, : + ] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[ + dim : dim * 2 + ] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[ + -dim :, : + ] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :] + # fmt: on + + +def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel) + return x + + +def reverse_correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, in_channel // 4, 4) + x = x[:, :, [0, 2, 1, 3]].transpose(1, 2).reshape(out_channel, in_channel) + + return x + + +def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel 
// 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + +# there was an incompatibility with this version, due to a new implementation of their downsampling operation using nn.Unfold. +# was resolved as seen here: +# https://github.com/open-mmlab/mmdetection/blob/31c84958f54287a8be2b99cbf87a6dcf12e57753/mmdet/models/utils/ckpt_convert.py#L96. +def reverse_correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(in_channel // 4, 4) + x = x[:, [0, 2, 1, 3]].transpose(0, 1).reshape(in_channel) + return x + + +def convert_upernet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): + model_name_to_url = { + "upernet-swin-tiny": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth", + "upernet-swin-small": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth", + "upernet-swin-base": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth", + "upernet-swin-large": "https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth", + } + checkpoint_url = model_name_to_url[model_name] + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", file_name=model_name)[ + "state_dict" + ] + + for name, param in state_dict.items(): + print(name, param.shape) + + config = get_upernet_config(model_name) + model = UperNetForSemanticSegmentation(config) + model.eval() + + # replace "bn" => "batch_norm" + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + if "bn" in key: + key = key.replace("bn", "batch_norm") + state_dict[key] = val + + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config.backbone_config) + + # fix downsample parameters + for key, value in state_dict.items(): + if "downsample" in key: + if "reduction" in key: + state_dict[key] = reverse_correct_unfold_reduction_order(value) + if "norm" in key: + state_dict[key] = reverse_correct_unfold_norm_order(value) + + model.load_state_dict(state_dict) + + # verify on image + url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + + processor = SegformerImageProcessor() + pixel_values = processor(image, return_tensors="pt").pixel_values + + with torch.no_grad(): + outputs = model(pixel_values) + logits = outputs.logits + + print(logits.shape) + print("First values of logits:", logits[0, 0, :3, :3]) + # assert values + if model_name == "upernet-swin-tiny": + expected_slice = torch.tensor( + [[-7.5958, -7.5958, -7.4302], [-7.5958, -7.5958, -7.4302], [-7.4797, -7.4797, -7.3068]] + ) + elif model_name == "upernet-swin-small": + expected_slice = torch.tensor( + [[-7.1921, -7.1921, -6.9532], [-7.1921, -7.1921, -6.9532], [-7.0908, 
-7.0908, -6.8534]] + ) + elif model_name == "upernet-swin-base": + expected_slice = torch.tensor( + [[-6.5851, -6.5851, -6.4330], [-6.5851, -6.5851, -6.4330], [-6.4763, -6.4763, -6.3254]] + ) + elif model_name == "upernet-swin-large": + expected_slice = torch.tensor( + [[-7.5297, -7.5297, -7.3802], [-7.5297, -7.5297, -7.3802], [-7.4044, -7.4044, -7.2586]] + ) + print("Logits:", outputs.logits[0, 0, :3, :3]) + assert torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to hub") + model.push_to_hub(f"openmmlab/{model_name}") + processor.push_to_hub(f"openmmlab/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="upernet-swin-tiny", + type=str, + choices=[f"upernet-swin-{size}" for size in ["tiny", "small", "base", "large"]], + help="Name of the Swin + UperNet model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_upernet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/src/transformers/models/upernet/modeling_upernet.py b/docs/transformers/src/transformers/models/upernet/modeling_upernet.py new file mode 100644 index 0000000000000000000000000000000000000000..0196169c37763bb261d12692b8abf3dba9f3913d --- /dev/null +++ b/docs/transformers/src/transformers/models/upernet/modeling_upernet.py @@ -0,0 +1,420 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UperNet model. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.""" + +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...modeling_outputs import SemanticSegmenterOutput +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...utils.backbone_utils import load_backbone +from .configuration_upernet import UperNetConfig + + +# General docstring +_CONFIG_FOR_DOC = "UperNetConfig" + + +class UperNetConvModule(nn.Module): + """ + A convolutional block that bundles conv/norm/activation layers. 
This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, Tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.batch_norm = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.conv(input) + output = self.batch_norm(output) + output = self.activation(output) + + return output + + +class UperNetPyramidPoolingBlock(nn.Module): + def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None: + super().__init__() + self.layers = [ + nn.AdaptiveAvgPool2d(pool_scale), + UperNetConvModule(in_channels, channels, kernel_size=1), + ] + for i, layer in enumerate(self.layers): + self.add_module(str(i), layer) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +class UperNetPyramidPoolingModule(nn.Module): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (`Tuple[int]`): + Pooling scales used in Pooling Pyramid Module. + in_channels (`int`): + Input channels. + channels (`int`): + Channels after modules, before conv_seg. + align_corners (`bool`): + align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None: + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.blocks = [] + for i, pool_scale in enumerate(pool_scales): + block = UperNetPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels) + self.blocks.append(block) + self.add_module(str(i), block) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + ppm_outs = [] + for ppm in self.blocks: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class UperNetHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://arxiv.org/abs/1807.10221). + """ + + def __init__(self, config, in_channels): + super().__init__() + + self.config = config + self.pool_scales = config.pool_scales # e.g. 
(1, 2, 3, 6)
+        self.in_channels = in_channels
+        self.channels = config.hidden_size
+        self.align_corners = False
+        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+        # PSP Module
+        self.psp_modules = UperNetPyramidPoolingModule(
+            self.pool_scales,
+            self.in_channels[-1],
+            self.channels,
+            align_corners=self.align_corners,
+        )
+        self.bottleneck = UperNetConvModule(
+            self.in_channels[-1] + len(self.pool_scales) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+        # FPN Module
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        for in_channels in self.in_channels[:-1]:  # skip the top layer
+            l_conv = UperNetConvModule(in_channels, self.channels, kernel_size=1)
+            fpn_conv = UperNetConvModule(self.channels, self.channels, kernel_size=3, padding=1)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        self.fpn_bottleneck = UperNetConvModule(
+            len(self.in_channels) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+    def psp_forward(self, inputs):
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        # build laterals
+        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+        laterals.append(self.psp_forward(encoder_hidden_states))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = laterals[i - 1].shape[2:]
+            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
+                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
+            )
+
+        # build outputs
+        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = nn.functional.interpolate(
+                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+            )
+        fpn_outs = torch.cat(fpn_outs, dim=1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.classifier(output)
+
+        return output
+
+
+class UperNetFCNHead(nn.Module):
+    """
+    Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
+    [FCN](https://arxiv.org/abs/1411.4038).
+
+    Args:
+        config:
+            Configuration, from which `auxiliary_in_channels`, `auxiliary_channels`, `auxiliary_num_convs` and
+            `auxiliary_concat_input` are read.
+        in_index (int):
+            Index of the encoder feature map to use as input. Default: 2.
+        kernel_size (int):
+            The kernel size for convs in the head. Default: 3.
+        dilation (int):
+            The dilation rate for convs in the head. Default: 1.
+ """ + + def __init__( + self, config, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1 + ) -> None: + super().__init__() + + self.config = config + self.in_channels = config.auxiliary_in_channels + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + UperNetConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + for i in range(self.num_convs - 1): + convs.append( + UperNetConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = UperNetConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +class UperNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = UperNetConfig + main_input_name = "pixel_values" + _no_split_modules = [] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.BatchNorm2d): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +UPERNET_START_DOCSTRING = r""" + Parameters: + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + config ([`UperNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +UPERNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SegformerImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers in case the backbone has them. See + `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers of the backbone. See `hidden_states` under + returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + """UperNet framework leveraging any vision backbone e.g. for ADE20k, CityScapes.""", + UPERNET_START_DOCSTRING, +) +class UperNetForSemanticSegmentation(UperNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.backbone = load_backbone(config) + + # Semantic segmentation head(s) + self.decode_head = UperNetHead(config, in_channels=self.backbone.channels) + self.auxiliary_head = UperNetFCNHead(config) if config.use_auxiliary_head else None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(UPERNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation + >>> from PIL import Image + >>> from huggingface_hub import hf_hub_download + + >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny") + >>> model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny") + + >>> filepath = hf_hub_download( + ... repo_id="hf-internal-testing/fixtures_ade20k", filename="ADE_val_00000001.jpg", repo_type="dataset" + ... 
) + >>> image = Image.open(filepath).convert("RGB") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> logits = outputs.logits # shape (batch_size, num_labels, height, width) + >>> list(logits.shape) + [1, 150, 512, 512] + ```""" + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + features = outputs.feature_maps + + logits = self.decode_head(features) + logits = nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False + ) + + loss = None + if labels is not None: + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index) + loss = loss_fct(logits, labels) + if auxiliary_logits is not None: + auxiliary_loss = loss_fct(auxiliary_logits, labels) + loss += self.config.auxiliary_loss_weight * auxiliary_loss + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["UperNetForSemanticSegmentation", "UperNetPreTrainedModel"] diff --git a/docs/transformers/src/transformers/models/video_llava/__init__.py b/docs/transformers/src/transformers/models/video_llava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5f5971a167997ffc89de76d5cc5ba3d8216b4b --- /dev/null +++ b/docs/transformers/src/transformers/models/video_llava/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
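+# A minimal usage sketch for the lazy-import wiring below (the access pattern is
+# the standard `_LazyModule` convention; nothing here is specific to Video-LLaVA
+# beyond the exported class name): importing the package stays cheap because
+# submodules such as `modeling_video_llava` are only loaded on first attribute access.
+#
+#     from transformers.models import video_llava
+#
+#     config_cls = video_llava.VideoLlavaConfig  # first access triggers the real import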
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_video_llava import *
+    from .image_processing_video_llava import *
+    from .modeling_video_llava import *
+    from .processing_video_llava import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/src/transformers/models/video_llava/configuration_video_llava.py b/docs/transformers/src/transformers/models/video_llava/configuration_video_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..402d99467316030b71b8d30150c3a8cc875e0852
--- /dev/null
+++ b/docs/transformers/src/transformers/models/video_llava/configuration_video_llava.py
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VideoLlava model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class VideoLlavaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`VideoLlavaForConditionalGeneration`]. It is used to instantiate a
+    VideoLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of LanguageBind/Video-LLaVA-7B-hf.
+
+    e.g. [LanguageBind/Video-LLaVA-7B-hf](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vision_config (`VideoLlavaVisionConfig`, *optional*):
+            Custom vision config or dict. Defaults to `CLIPVisionConfig` if not indicated.
+        text_config (`Union[AutoConfig, dict]`, *optional*):
+            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
+            Defaults to `LlamaConfig` if not indicated.
+        image_token_index (`int`, *optional*, defaults to 32000):
+            The image token index to encode the image prompt.
+        video_token_index (`int`, *optional*, defaults to 32001):
+            The video token index to encode the video prompt.
+        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The activation function used by the multimodal projector.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+            The feature selection strategy used to select the vision feature from the CLIP backbone.
+            Can be either "full" to select all features or "default" to select features without `CLS`.
+ vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 256): + Sequence length of one image embedding. + video_seq_length (`int`, *optional*, defaults to 2056): + Sequence length of one video embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. + + Example: + + ```python + >>> from transformers import VideoLlavaForConditionalGeneration, VideoLlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a VideoLlava video_llava-1.5-7b style configuration + >>> configuration = VideoLlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the video_llava-1.5-7b style configuration + >>> model = VideoLlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "video_llava" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + video_token_index=32001, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=256, + video_seq_length=2056, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.video_token_index = video_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.image_seq_length = image_seq_length + self.video_seq_length = video_seq_length + self.multimodal_projector_bias = multimodal_projector_bias + + self.vision_config = vision_config + + if isinstance(self.vision_config, dict): + if "model_type" not in vision_config: + vision_config["model_type"] = "clip_vision_model" + logger.warning("Key=`model_type` not found in vision config, setting it to `clip_vision_model`") + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=224, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + if isinstance(text_config, dict): + if "model_type" not in text_config: + text_config["model_type"] = "llama" + logger.warning("Key=`model_type` not found in text config, setting it to `llama`") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + super().__init__(**kwargs) + + +__all__ = ["VideoLlavaConfig"] diff --git a/docs/transformers/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py b/docs/transformers/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py new file mode 100644 index 
0000000000000000000000000000000000000000..fff886f8a833564fd1389c60ef92477565785852 --- /dev/null +++ b/docs/transformers/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py @@ -0,0 +1,159 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import torch +from huggingface_hub import hf_hub_download + +from transformers import ( + AddedToken, + AutoConfig, + AutoTokenizer, + VideoLlavaConfig, + VideoLlavaForConditionalGeneration, + VideoLlavaImageProcessor, + VideoLlavaProcessor, +) + + +EPILOG_TXT = """Example: + python transformers/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py --text_model_id lmsys/vicuna-7b-v1.5 --vision_model_id openai/clip-vit-large-patch14 --output_hub_path org/video_llava-7b --old_state_dict_id LanguageBind/Video-LLaVA-7B + +Example for creating the old state dict file with Python: + + import torch + from video_llava.model.language_model.video_llava import VideoLlavaForCausalLM + + # load model + kwargs = {"device_map": "auto", "torch_dtype": torch.float16} + model = VideoLlavaForCausalLM.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", low_cpu_mem_usage=True, **kwargs) + + # load vision tower + model.get_vision_tower().load_model() + + # Save state dict + torch.save(model.state_dict(), "tmp/hf_models/video_llava-7b/model_state_dict.bin") +""" + +KEYS_TO_MODIFY_MAPPING = { + "model.video_tower.video_tower": "video_tower", + "model.image_tower.image_tower": "image_tower", + "model.mm_projector": "multi_modal_projector", + "model": "language_model.model", + "lm_head": "language_model.lm_head", + "video_tower": "video_tower.vision_model", + "image_tower": "image_tower.vision_model", + "multi_modal_projector.0": "multi_modal_projector.linear_1", + "multi_modal_projector.2": "multi_modal_projector.linear_2", +} + + +def convert_state_dict_to_hf(state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if key.endswith(".inv_freq"): + continue + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + new_state_dict[key] = value + return new_state_dict + + +def convert_video_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id): + torch.set_default_dtype(torch.float16) + text_config = AutoConfig.from_pretrained(text_model_id) + + tokenizer = AutoTokenizer.from_pretrained(text_model_id) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) + tokenizer.add_tokens(AddedToken("